from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
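# Scheduled TensorIR dump (TVMScript) of GPU kernels. The shapes throughout
# (4096 hidden size, 32 heads x 128 head dim, 32001 vocab, 4-bit weights packed
# 8 per uint32 with one fp16 scale per 32 elements) suggest an MLC-LLM-style
# build of a LLaMA-7B-class model; the per-kernel notes below are inferred from
# the code itself, not from the original author.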

@I.ir_module
class Module:
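    # divide2: elementwise division of the (1, 1, 32001) fp32 logits by the
    # scalar B (e.g. a sampling temperature). 126 blocks x 256 threads give
    # 32256 lanes; T.where masks the 255-lane tail past 32001.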
    @T.prim_func(private=True)
    def divide2(A: T.Buffer((1, 1, 32001), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((1, 1, 32001), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(126, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_divide"):
                    v0 = T.axis.spatial(32001, ax0_fused_0 * 256 + ax0_fused_1)
                    T.where(ax0_fused_0 * 256 + ax0_fused_1 < 32001)
                    T.reads(A[0, 0, v0], B[()])
                    T.writes(T_divide[0, 0, v0])
                    T_divide[0, 0, v0] = A[0, 0, v0] / B[()]

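    # extend_te: widens an (n, n) fp16 attention mask to (n, m). The first
    # m - n output columns are filled with the fp16 max (65504, i.e. "attend
    # allowed", given how the attention kernels below apply masks via T.min);
    # the remaining columns copy A shifted by n - m. One thread per element,
    # tail guarded by T.where.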
    @T.prim_func(private=True)
    def extend_te(var_A: T.handle, var_concat_te: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, 1, n, n), "float16")
        m = T.int32()
        concat_te = T.match_buffer(var_concat_te, (1, 1, n, m), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((n * m + 255) // 256, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("concat_te"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % (m * n) // m)
                    v1 = T.axis.spatial(m, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % m)
                    T.where(ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 < n * m)
                    T.reads(A[0, 0, v0, v1 + (n - m)])
                    T.writes(concat_te[0, 0, v0, v1])
                    concat_te[0, 0, v0, v1] = T.if_then_else(v1 < m - n, T.float16(65504), A[0, 0, v0, v1 + (n - m)])

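    # full: fills a (1, 1, 1, n) fp16 buffer with the fp16 max (65504), i.e.
    # an all-permissive mask row; one thread per element, tail guarded by
    # T.where.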
    @T.prim_func(private=True)
    def full(var_T_full: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        T_full = T.match_buffer(var_T_full, (1, 1, 1, n), "float16")
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding((n + 255) // 256, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_full"):
                    v0 = T.axis.spatial(n, ax0_fused_0 * 256 + ax0_fused_1)
                    T.where(ax0_fused_0 * 256 + ax0_fused_1 < n)
                    T.reads()
                    T.writes(T_full[0, 0, 0, v0])
                    T_full[0, 0, 0, v0] = T.float16(65504)

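    # fused_NT_matmul1_divide1_maximum1_minimum1_cast3: batched attention-score
    # GEMM for prefill. Computes lv28 (1, 32, n, 128) @ lv29^T (1, 32, m, 128)
    # over the 128-wide reduction axis in 32x64 output tiles per block, staging
    # both operands in shared memory and accumulating in registers ("local").
    # The epilogue scales by 0.08839779 (~= 1/sqrt(128)), clamps below at
    # -65504, applies the mask lv5 via T.min, and casts to fp32. Rows are
    # padded to multiples of 32 and columns to multiples of 64, with bound
    # checks on store.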
    @T.prim_func(private=True)
    def fused_NT_matmul1_divide1_maximum1_minimum1_cast3(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv28 = T.match_buffer(p_lv28, (1, 32, n, 128), "float16")
        m = T.int32()
        lv29 = T.match_buffer(p_lv29, (1, 32, m, 128), "float16")
        lv5 = T.match_buffer(p_lv5, (1, 1, n, m), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (1, 32, n, m))
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((32, (n + 31) // 32 * 32, (m + 63) // 64 * 64), "float16", scope="local")
        lv28_reindex_pad_shared = T.alloc_buffer((32, (n + 31) // 32 * 32, 128), "float16", scope="shared")
        lv29_reindex_pad_shared = T.alloc_buffer((32, (m + 63) // 64 * 64, 128), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding((m + 63) // 64 * 32, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_init in T.grid(4, 4):
                                    with T.block("NT_matmul_init"):
                                        v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64))
                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                        v2 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
                                        T.reads()
                                        T.writes(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                        var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] = T.float16(0)
                                for ax3_0 in range(8):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(2):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("lv28_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64))
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(128, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv28[0, v0, v1, v2])
                                                        T.writes(lv28_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        lv28_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv28[0, v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("lv29_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64))
                                                        v1 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(128, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv29[0, v0, v1, v2])
                                                        T.writes(lv29_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        lv29_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, lv29[0, v0, v1, v2], T.float16(0))
                                    for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
                                        with T.block("NT_matmul_update"):
                                            v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64))
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                            v2 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
                                            v3 = T.axis.reduce(128, ax3_0 * 16 + ax3_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv28_reindex_pad_shared[v0, v1, v3], lv29_reindex_pad_shared[v0, v2, v3])
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv28_reindex_pad_shared[v0, v1, v3] * lv29_reindex_pad_shared[v0, v2, v3]
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
                                    for ax2_1_1 in T.vectorized(2):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64) + ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv5[0, 0, v1, v2])
                                            T.writes(var_compute_intermediate[0, v0, v1, v2])
                                            if v1 < n and v2 < m:
                                                var_compute_intermediate[0, v0, v1, v2] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.float16(0.088397790055248615), T.float16(-65504)), lv5[0, 0, v1, v2]))

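    # fused_NT_matmul7_divide_maximum_minimum_cast: decode-time counterpart of
    # the kernel above for a single query row. Reduces the 128-wide dot product
    # with a two-stage rfactor (128 partial sums -> 64 threads -> 1), staging
    # the query lv1637 in shared memory, then applies the same ~1/sqrt(128)
    # scale, -65504 clamp, T.min mask (lv1614), and fp32 cast. One block per
    # (head, key position) pair: blockIdx.x ranges over n * 32.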
    @T.prim_func(private=True)
    def fused_NT_matmul7_divide_maximum_minimum_cast(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16")
        lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n))
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 32, 1, n), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 32, 1, n), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 32, 1, n), "float16", scope="local")
        lv1638_local = T.alloc_buffer((1, 32, n, 128), "float16", scope="local")
        lv1637_shared = T.alloc_buffer((1, 32, 1, 128), "float16", scope="shared")
        for ax0_fused_ax1_fused_fused_0 in T.thread_binding(n * 32, thread="blockIdx.x"):
            for ax0_fused_ax1_fused_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(64, thread="threadIdx.x"):
                    for ax0, ax1, ax2 in T.grid(1, 1, 1):
                        for ax3_0 in T.serial(1, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax3_1 in T.thread_binding(1, thread="threadIdx.y"):
                                for ax3_2 in T.thread_binding(64, thread="threadIdx.x"):
                                    for ax3_3 in T.vectorized(2):
                                        with T.block("lv1637_shared"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1)
                                            v2 = T.axis.spatial(1, ax2)
                                            v3 = T.axis.spatial(128, ax3_0 * 128 + ax3_1 * 128 + ax3_2 * 2 + ax3_3)
                                            T.reads(lv1637[v0, v1, v2, v3])
                                            T.writes(lv1637_shared[v0, v1, v2, v3])
                                            lv1637_shared[v0, v1, v2, v3] = lv1637[v0, v1, v2, v3]
                    for ax0_fused_ax1_fused_fused_2_init in range(1):
                        for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
                                v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) // n)
                                v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) % n)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1])
                                var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = T.float16(0)
                    for ax2_fused_u_fused_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2):
                            for ax2_1 in T.vectorized(1):
                                with T.block("lv1638_local"):
                                    v0 = T.axis.spatial(1, ax0)
                                    v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1)
                                    v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1)
                                    v3 = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3)
                                    T.reads(lv1638[v0, v1, v2, v3])
                                    T.writes(lv1638_local[v0, v1, v2, v3])
                                    lv1638_local[v0, v1, v2, v3] = lv1638[v0, v1, v2, v3]
                        for ax0_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(1, 1):
                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
                                    v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) // n)
                                    v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) % n)
                                    vax2_fused_u_fused_2, vax2_fused_u_fused_0 = T.axis.remap("RR", [ax2_fused_u_fused_2, ax2_fused_u_fused_0])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1], lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused], lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1])
                                    var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] + lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused] * lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused]
            for ax2_ax3_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
                    for ax2_ax3_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_ax3_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(64, ax0)
                                v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
                                v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1])
                                var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = T.float16(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
                                    v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1], var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1])
                                    var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1]
            for ax1_ax2_fused_1 in range(1):
                for ax1_ax2_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(64, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(64, ax0)
                            v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
                            v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1])
                            T.writes(var_NT_matmul_intermediate_local[0, v0, 0, v1])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, v0, 0, v1] = T.float16(0)
                            var_NT_matmul_intermediate_local[0, v0, 0, v1] = var_NT_matmul_intermediate_local[0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]
            for ax0_ax1_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
                for ax0_ax1_fused_1 in range(1):
                    with T.block("compute"):
                        v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
                        v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
                        T.reads(var_NT_matmul_intermediate_local[0, v0, 0, v1], lv1614[0, 0, 0, v1])
                        T.writes(var_compute_intermediate[0, v0, 0, v1])
                        var_compute_intermediate[0, v0, 0, v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[0, v0, 0, v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1614[0, 0, 0, v1]))

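    # fused_fused_decode1_fused_NT_matmul5_cast2: GEMV of the (1, 1, 4096) fp16
    # activation against a 4-bit-quantized (32001, 4096) weight, producing fp32
    # logits. Each uint32 in lv771 packs eight 4-bit values; dequantization is
    # (nibble - 7) * scale, with one fp16 scale per 32 weights in lv772. Uses
    # an rfactor reduction (16 -> 8 -> 1 partials per output row), a
    # shared-memory copy of the activation, and 501 blocks x 64 rows = 32064
    # lanes with T.where guarding the tail past 32001.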
    @T.prim_func(private=True)
    def fused_fused_decode1_fused_NT_matmul5_cast2(lv771: T.Buffer((32001, 512), "uint32"), lv772: T.Buffer((32001, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32001), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 32001), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 32001), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 32001), "float16", scope="local")
        lv771_local = T.alloc_buffer((32001, 512), "uint32", scope="local")
        lv3216_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(501, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("lv3216_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                            T.reads(lv3216[v0, v1, v2])
                                            T.writes(lv3216_shared[v0, v1, v2])
                                            lv3216_shared[v0, v1, v2] = lv3216[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.where(u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init < 32001)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv771_local"):
                                    v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.where(u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 < 32001)
                                    T.reads(lv771[v0, v1])
                                    T.writes(lv771_local[v0, v1])
                                    lv771_local[v0, v1] = lv771[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.where(u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2 < 32001)
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv772[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.where(u_fused_ax0_fused_fused_0 * 64 + (ax2_fused_0 + (ax2_fused_1_0 + ax2_fused_1_1)) < 32001)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.where(u_fused_ax0_fused_fused_0 * 64 + (ax2_fused_0 + (ax2_fused_1_0 + ax2_fused_1_1)) < 32001)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
                            T.where(u_fused_ax0_fused_fused_0 * 64 + (ax1_fused_0 + ax1_fused_1) < 32001)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0)
                            var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
            for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                for ax0_fused_1 in range(1):
                    with T.block("compute"):
                        v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1)
                        T.where(u_fused_ax0_fused_fused_0 * 64 + (ax0_fused_0 + ax0_fused_1) < 32001)
                        T.reads(var_NT_matmul_intermediate_local[0, 0, v0])
                        T.writes(p_output0_intermediate[0, 0, v0])
                        p_output0_intermediate[0, 0, v0] = T.Cast("float32", var_NT_matmul_intermediate_local[0, 0, v0])

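    # fused_fused_decode1_take: gathers one embedding row. lv1611 holds a
    # single token id; the 4096 fp16 values are dequantized on the fly from the
    # packed 4-bit table lv / scales lv1 using the same (nibble - 7) * scale
    # scheme as above.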
    @T.prim_func(private=True)
    def fused_fused_decode1_take(lv: T.Buffer((32001, 512), "uint32"), lv1: T.Buffer((32001, 128), "float16"), lv1611: T.Buffer((1,), "int32"), var_T_take_intermediate: T.Buffer((1, 4096), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_take"):
                    v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
                    T.reads(lv[lv1611[0], v0 // 8], lv1611[0], lv1[lv1611[0], v0 // 32])
                    T.writes(var_T_take_intermediate[0, v0])
                    var_T_take_intermediate[0, v0] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv[lv1611[0], v0 // 8], T.Cast("uint32", v0 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1[lv1611[0], v0 // 32]

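    # fused_fused_decode1_take1: batched variant of the gather above for n
    # token ids (prefill), producing an (n, 4096) fp16 embedding matrix with
    # the same on-the-fly dequantization.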
    @T.prim_func(private=True)
    def fused_fused_decode1_take1(lv775: T.Buffer((32001, 512), "uint32"), lv776: T.Buffer((32001, 128), "float16"), p_lv: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv = T.match_buffer(p_lv, (n,), "int32")
        var_T_take_intermediate = T.match_buffer(p_output0, (n, 4096), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_take"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
                    v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
                    T.reads(lv775[lv[v0], v1 // 8], lv[v0], lv776[lv[v0], v1 // 32])
                    T.writes(var_T_take_intermediate[v0, v1])
                    var_T_take_intermediate[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv775[lv[v0], v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv776[lv[v0], v1 // 32]

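    # fused_fused_decode2_NT_matmul: prefill GEMM (1, n, 4096) x
    # (12288, 4096)^T -> (1, n, 12288), with the weight dequantized while it is
    # staged into shared memory. 12288 = 3 * 4096, so this is plausibly the
    # fused QKV projection. Same 32x64 tiling, register accumulation, and row
    # padding to multiples of 32 as the attention GEMM above.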
    @T.prim_func(private=True)
    def fused_fused_decode2_NT_matmul(lv779: T.Buffer((12288, 512), "uint32"), lv780: T.Buffer((12288, 128), "float16"), p_lv6: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv6 = T.match_buffer(p_lv6, (1, n, 4096), "float16")
        var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 12288), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 12288), "float16", scope="local")
        lv6_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="shared")
        p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 12288, 4096), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(192, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_init in T.grid(4, 4):
                                    with T.block("NT_matmul_init"):
                                        v0 = T.axis.spatial(1, 0)
                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                        v2 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
                                        T.reads()
                                        T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                        var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
                                for ax3_0 in range(256):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(2):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("lv6_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv6[v0, v1, v2])
                                                        T.writes(lv6_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        lv6_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv6[v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("p_output0_intermediate_reindex_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv779[v1, v2 // 8], lv780[v1, v2 // 32])
                                                        T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv779[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv780[v1, v2 // 32]
                                    for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
                                        with T.block("NT_matmul_update"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                            v2 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
                                            v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv6_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv6_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
                                    for ax2_1_1 in T.vectorized(2):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                            T.writes(var_NT_matmul_intermediate[0, v1, v2])
                                            if v1 < n:
                                                var_NT_matmul_intermediate[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

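    # fused_fused_decode2_NT_matmul6: decode-time GEMV counterpart of the
    # kernel above, (1, 1, 4096) x (12288, 4096)^T -> (1, 1, 12288), using the
    # rfactor pattern (16 -> 8 -> 1 partials) with the activation in shared
    # memory and 4-bit dequantization fused into the inner product.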
    @T.prim_func(private=True)
    def fused_fused_decode2_NT_matmul6(lv3: T.Buffer((12288, 512), "uint32"), lv4: T.Buffer((12288, 128), "float16"), lv1615: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 12288), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 12288), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 12288), "float16", scope="local")
        lv3_local = T.alloc_buffer((12288, 512), "uint32", scope="local")
        lv1615_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(192, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("lv1615_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                            T.reads(lv1615[v0, v1, v2])
                                            T.writes(lv1615_shared[v0, v1, v2])
                                            lv1615_shared[v0, v1, v2] = lv1615[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv3_local"):
                                    v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv3[v0, v1])
                                    T.writes(lv3_local[v0, v1])
                                    lv3_local[v0, v1] = lv3[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1615_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv3_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv4[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1615_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv3_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv4[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate[0, 0, v0] = T.float16(0)
                            var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]

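    # fused_fused_decode3_fused_NT_matmul2_add1: prefill GEMM (1, n, 4096) x
    # (4096, 4096)^T with a fused residual add of lv2 in the epilogue;
    # plausibly the attention output (or another square 4096x4096) projection.
    # Tiling and dequantization mirror fused_fused_decode2_NT_matmul.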
    @T.prim_func(private=True)
    def fused_fused_decode3_fused_NT_matmul2_add1(lv784: T.Buffer((4096, 512), "uint32"), lv785: T.Buffer((4096, 128), "float16"), p_lv41: T.handle, p_lv2: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv41 = T.match_buffer(p_lv41, (1, n, 4096), "float16")
        lv2 = T.match_buffer(p_lv2, (1, n, 4096), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="local")
        lv41_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="shared")
        p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 4096, 4096), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_init in T.grid(4, 4):
                                    with T.block("NT_matmul_init"):
                                        v0 = T.axis.spatial(1, 0)
                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                        v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
                                        T.reads()
                                        T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                        var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
                                for ax3_0 in range(256):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(2):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("lv41_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv41[v0, v1, v2])
                                                        T.writes(lv41_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        lv41_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv41[v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("p_output0_intermediate_reindex_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv784[v1, v2 // 8], lv785[v1, v2 // 32])
                                                        T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv784[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv785[v1, v2 // 32]
                                    for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
                                        with T.block("NT_matmul_update"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                            v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
                                            v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv41_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv41_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
                                    for ax2_1_1 in T.vectorized(2):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
                                            T.reads(lv2[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                            T.writes(p_output0_intermediate[0, v1, v2])
                                            if v1 < n:
                                                p_output0_intermediate[0, v1, v2] = lv2[0, v1, v2] + var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

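    # Decode-step dequantize + GEMV + residual add: multiplies lv14 (1, 1, 4096) by a
    # 4-bit-quantized (4096, 4096) weight and adds lv1613. Each uint32 in lv15 packs
    # eight 4-bit nibbles, with one fp16 scale per 32 weights in lv16; a weight is
    # decoded as (verified against the update statement below)
    #   w[i, k] = (float16((lv15[i, k // 8] >> (4 * (k % 8))) & 0xF) - 7) * lv16[i, k // 32]
    # The k-reduction is rfactor-split: 16 register partials per output, folded to 8,
    # then reduced across threadIdx.x in the final "NT_matmul" block.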
    @T.prim_func(private=True)
    def fused_fused_decode3_fused_NT_matmul8_add(lv15: T.Buffer((4096, 512), "uint32"), lv16: T.Buffer((4096, 128), "float16"), lv14: T.Buffer((1, 1, 4096), "float16"), lv1613: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 4096), "float16", scope="local")
        lv15_local = T.alloc_buffer((4096, 512), "uint32", scope="local")
        lv14_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("lv14_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                            T.reads(lv14[v0, v1, v2])
                                            T.writes(lv14_shared[v0, v1, v2])
                                            lv14_shared[v0, v1, v2] = lv14[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv15_local"):
                                    v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv15[v0, v1])
                                    T.writes(lv15_local[v0, v1])
                                    lv15_local[v0, v1] = lv15[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv14_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv15_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv16[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv14_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv15_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv16[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0)
                            var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
            for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                for ax0_fused_1 in range(1):
                    with T.block("T_add"):
                        v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1)
                        T.reads(lv1613[0, 0, v0], var_NT_matmul_intermediate_local[0, 0, v0])
                        T.writes(p_output0_intermediate[0, 0, v0])
                        p_output0_intermediate[0, 0, v0] = lv1613[0, 0, v0] + var_NT_matmul_intermediate_local[0, 0, v0]

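    # Prefill-shape dequantize + GEMM: (1, n, 4096) activations times the packed 4-bit
    # (22016, 4096) weight (lv788 nibbles, lv789 scales, same decode formula as above),
    # giving (1, n, 22016). Tiled schedule: 32 x 64 output tiles, 256 k-panels of 16
    # staged through shared memory, 4 x 4 register accumulators per thread, and the
    # row padding masked off with v1 < n on both load and store.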
    @T.prim_func(private=True)
    def fused_fused_decode4_NT_matmul3(lv788: T.Buffer((22016, 512), "uint32"), lv789: T.Buffer((22016, 128), "float16"), p_lv45: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16")
        var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 22016), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 22016), "float16", scope="local")
        lv45_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="shared")
        p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 22016, 4096), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(344, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_init in T.grid(4, 4):
                                    with T.block("NT_matmul_init"):
                                        v0 = T.axis.spatial(1, 0)
                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                        v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
                                        T.reads()
                                        T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                        var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
                                for ax3_0 in range(256):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(2):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("lv45_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv45[v0, v1, v2])
                                                        T.writes(lv45_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        lv45_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv45[v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("p_output0_intermediate_reindex_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv788[v1, v2 // 8], lv789[v1, v2 // 32])
                                                        T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv788[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv789[v1, v2 // 32]
                                    for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
                                        with T.block("NT_matmul_update"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                            v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
                                            v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv45_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv45_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
                                    for ax2_1_1 in T.vectorized(2):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                            T.writes(var_NT_matmul_intermediate[0, v1, v2])
                                            if v1 < n:
                                                var_NT_matmul_intermediate[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

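    # Decode-step counterpart of the GEMM above: GEMV (1, 1, 4096) -> (1, 1, 22016)
    # with the same 4-bit decode, scheduled as the two-stage rfactor reduction
    # (16 -> 8 partials, then a cross-thread reduce) over 344 blocks of 64 outputs.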
    @T.prim_func(private=True)
    def fused_fused_decode4_NT_matmul9(lv19: T.Buffer((22016, 512), "uint32"), lv20: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 22016), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 22016), "float16", scope="local")
        lv19_local = T.alloc_buffer((22016, 512), "uint32", scope="local")
        lv1654_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(344, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("lv1654_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                            T.reads(lv1654[v0, v1, v2])
                                            T.writes(lv1654_shared[v0, v1, v2])
                                            lv1654_shared[v0, v1, v2] = lv1654[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv19_local"):
                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv19[v0, v1])
                                    T.writes(lv19_local[v0, v1])
                                    lv19_local[v0, v1] = lv19[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv19_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv20[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv19_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv20[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate[0, 0, v0] = T.float16(0)
                            var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]

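    # Decode-step down-projection: GEMV (1, 1, 11008) -> (1, 1, 4096) over 4-bit weights
    # (lv23 nibbles, lv24 scales), then a residual add with lv18. Same rfactor structure;
    # the shared staging of lv22 needs a T.where guard because 11008 is not a multiple
    # of its 512-element copy step (22 * 512 = 11264).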
    @T.prim_func(private=True)
    def fused_fused_decode5_fused_NT_matmul10_add(lv23: T.Buffer((4096, 1376), "uint32"), lv24: T.Buffer((4096, 344), "float16"), lv22: T.Buffer((1, 1, 11008), "float16"), lv18: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 4096), "float16", scope="local")
        lv23_local = T.alloc_buffer((4096, 1376), "uint32", scope="local")
        lv22_shared = T.alloc_buffer((1, 1, 11008), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(22, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(1):
                                        with T.block("lv22_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(11008, ax2_0 * 512 + ax2_1 * 8 + ax2_2 + ax2_3)
                                            T.where((ax2_0 * 64 + ax2_1) * 8 + ax2_2 + ax2_3 < 11008)
                                            T.reads(lv22[v0, v1, v2])
                                            T.writes(lv22_shared[v0, v1, v2])
                                            lv22_shared[v0, v1, v2] = lv22[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(172, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv23_local"):
                                    v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(1376, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv23[v0, v1])
                                    T.writes(lv23_local[v0, v1])
                                    lv23_local[v0, v1] = lv23[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv22_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv23_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv24[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv22_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv23_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv24[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0)
                            var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
            for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
                for ax0_fused_1 in range(1):
                    with T.block("T_add"):
                        v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1)
                        T.reads(lv18[0, 0, v0], var_NT_matmul_intermediate_local[0, 0, v0])
                        T.writes(p_output0_intermediate[0, 0, v0])
                        p_output0_intermediate[0, 0, v0] = lv18[0, 0, v0] + var_NT_matmul_intermediate_local[0, 0, v0]

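    # Prefill-shape down-projection: (1, n, 11008) times the packed 4-bit (4096, 11008)
    # weight plus the residual lv787, using the same 32 x 64 tiled GEMM schedule with
    # 688 k-panels of 16 (688 * 16 = 11008).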
    @T.prim_func(private=True)
    def fused_fused_decode5_fused_NT_matmul4_add1(lv792: T.Buffer((4096, 1376), "uint32"), lv793: T.Buffer((4096, 344), "float16"), p_lv791: T.handle, p_lv787: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv791 = T.match_buffer(p_lv791, (1, n, 11008), "float16")
        lv787 = T.match_buffer(p_lv787, (1, n, 4096), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="local")
        lv791_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 11008), "float16", scope="shared")
        p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 4096, 11008), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_init in T.grid(4, 4):
                                    with T.block("NT_matmul_init"):
                                        v0 = T.axis.spatial(1, 0)
                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                        v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
                                        T.reads()
                                        T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                        var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
                                for ax3_0 in range(688):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(2):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("lv791_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(11008, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv791[v0, v1, v2])
                                                        T.writes(lv791_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        lv791_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv791[v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("p_output0_intermediate_reindex_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(11008, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(lv792[v1, v2 // 8], lv793[v1, v2 // 32])
                                                        T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv792[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv793[v1, v2 // 32]
                                    for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
                                        with T.block("NT_matmul_update"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                            v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
                                            v3 = T.axis.reduce(11008, ax3_0 * 16 + ax3_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv791_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv791_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
                                    for ax2_1_1 in T.vectorized(2):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
                                            T.reads(lv787[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                            T.writes(p_output0_intermediate[0, v1, v2])
                                            if v1 < n:
                                                p_output0_intermediate[0, v1, v2] = lv787[0, v1, v2] + var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]

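    # Builds the (1, 1, n, n) causal attention mask: -65504 (fp16 lowest) strictly above
    # the diagonal, +65504 on and below it, over a flat 1-D launch of 256-thread blocks.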
    @T.prim_func(private=True)
    def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int32):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (1, 1, n, n), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((n * n + 255) // 256, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_broadcast_to"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // n)
                    v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % n)
                    T.where(ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 < n * n)
                    T.reads()
                    T.writes(var_T_broadcast_to_intermediate[0, 0, v0, v1])
                    var_T_broadcast_to_intermediate[0, 0, v0, v1] = T.Select(v0 < v1, T.float16(-65504), T.float16(65504))

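    # Layout helper: reshapes the (1, 1, 4096) hidden vector into 32 heads of 128
    # (1, 32, 128), dropping the unit sequence axis.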
    @T.prim_func(private=True)
    def fused_reshape2_squeeze(lv_1: T.Buffer((1, 1, 4096), "float16"), var_T_squeeze_intermediate: T.Buffer((1, 32, 128), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_squeeze"):
                    v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 128)
                    v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 128)
                    T.reads(lv_1[0, 0, v0 * 128 + v1])
                    T.writes(var_T_squeeze_intermediate[0, v0, v1])
                    var_T_squeeze_intermediate[0, v0, v1] = lv_1[0, 0, v0 * 128 + v1]

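    # Layout helper: the same 32 x 128 head split, but keeping a unit sequence axis to
    # produce (1, 32, 1, 128); presumably the per-head query layout for the
    # single-token attention step.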
    @T.prim_func(private=True)
    def fused_reshape2_transpose5(lv_0: T.Buffer((1, 1, 4096), "float16"), var_T_transpose_intermediate: T.Buffer((1, 32, 1, 128), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_transpose"):
                    v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 128)
                    v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 128)
                    T.reads(lv_0[0, 0, v0 * 128 + v1])
                    T.writes(var_T_transpose_intermediate[0, v0, 0, v1])
                    var_T_transpose_intermediate[0, v0, 0, v1] = lv_0[0, 0, v0 * 128 + v1]

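    # Prefill attention softmax over the last axis of the (1, 32, n, m) fp32 scores,
    # fused with the cast back to fp16. Numerically stable form:
    #   softmax(x)_j = exp(x_j - max_k x_k) / sum_k exp(x_k - max_k x_k)
    # with the per-row max and exp-sum kept in shared memory and reduced by 64 threads.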
    @T.prim_func(private=True)
    def fused_softmax1_cast4(p_lv36: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n, m = T.int32(), T.int32()
        lv36 = T.match_buffer(p_lv36, (1, 32, n, m))
        var_compute_intermediate = T.match_buffer(p_output0, (1, 32, n, m), "float16")
        # with T.block("root"):
        T_softmax_maxelem_shared = T.alloc_buffer((1, 32, n), scope="shared")
        T_softmax_expsum_shared = T.alloc_buffer((1, 32, n), scope="shared")
        for ax0_ax1_fused in T.thread_binding(n * 32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for ax0, ax1, ax2_fused_0 in T.grid(1, 1, (m + 63) // 64):
                for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_maxelem"):
                        v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax0)
                        v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
                        v2 = T.axis.reduce(m, ax2_fused_0 * 64 + ax2_fused_1)
                        T.where(ax2_fused_0 * 64 + ax2_fused_1 < m)
                        T.reads(lv36[0, v0, v1, v2])
                        T.writes(T_softmax_maxelem_shared[0, v0, v1])
                        with T.init():
                            T_softmax_maxelem_shared[0, v0, v1] = T.float32(-3.4028234663852886e+38)
                        T_softmax_maxelem_shared[0, v0, v1] = T.max(T_softmax_maxelem_shared[0, v0, v1], lv36[0, v0, v1, v2])
            for ax0, ax1, ax2_fused_0 in T.grid(1, 1, (m + 63) // 64):
                for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_expsum"):
                        v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax0)
                        v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
                        v2 = T.axis.reduce(m, ax2_fused_0 * 64 + ax2_fused_1)
                        T.where(ax2_fused_0 * 64 + ax2_fused_1 < m)
                        T.reads(lv36[0, v0, v1, v2], T_softmax_maxelem_shared[0, v0, v1])
                        T.writes(T_softmax_expsum_shared[0, v0, v1])
                        with T.init():
                            T_softmax_expsum_shared[0, v0, v1] = T.float32(0)
                        T_softmax_expsum_shared[0, v0, v1] = T_softmax_expsum_shared[0, v0, v1] + T.exp(lv36[0, v0, v1, v2] - T_softmax_maxelem_shared[0, v0, v1])
            for ax2_0 in range((m + 63) // 64):
                for ax2_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("compute"):
                        v0 = T.axis.spatial(32, ax0_ax1_fused // n)
                        v1 = T.axis.spatial(n, ax0_ax1_fused % n)
                        v2 = T.axis.spatial(m, ax2_0 * 64 + ax2_1)
                        T.where(ax2_0 * 64 + ax2_1 < m)
                        T.reads(lv36[0, v0, v1, v2], T_softmax_maxelem_shared[0, v0, v1], T_softmax_expsum_shared[0, v0, v1])
                        T.writes(var_compute_intermediate[0, v0, v1, v2])
                        var_compute_intermediate[0, v0, v1, v2] = T.Cast("float16", T.exp(lv36[0, v0, v1, v2] - T_softmax_maxelem_shared[0, v0, v1]) / T_softmax_expsum_shared[0, v0, v1])

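    # Decode-step variant of the softmax above for (1, 32, 1, n) scores: one block per
    # head, same stable max/exp-sum reduction, cast to fp16 on the way out.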
    @T.prim_func(private=True)
    def fused_softmax_cast1(p_lv1645: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv1645 = T.match_buffer(p_lv1645, (1, 32, 1, n))
        var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n), "float16")
        # with T.block("root"):
        T_softmax_maxelem_shared = T.alloc_buffer((1, 32, 1), scope="shared")
        T_softmax_expsum_shared = T.alloc_buffer((1, 32, 1), scope="shared")
        for ax0_fused in T.thread_binding(32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for ax0, ax1_fused_0 in T.grid(1, (n + 63) // 64):
                for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_maxelem"):
                        v0 = T.axis.spatial(32, ax0_fused + ax0)
                        v1 = T.axis.reduce(n, ax1_fused_0 * 64 + ax1_fused_1)
                        T.where(ax1_fused_0 * 64 + ax1_fused_1 < n)
                        T.reads(lv1645[0, v0, 0, v1])
                        T.writes(T_softmax_maxelem_shared[0, v0, 0])
                        with T.init():
                            T_softmax_maxelem_shared[0, v0, 0] = T.float32(-3.4028234663852886e+38)
                        T_softmax_maxelem_shared[0, v0, 0] = T.max(T_softmax_maxelem_shared[0, v0, 0], lv1645[0, v0, 0, v1])
            for ax0, ax1_fused_0 in T.grid(1, (n + 63) // 64):
                for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_expsum"):
                        v0 = T.axis.spatial(32, ax0_fused + ax0)
                        v1 = T.axis.reduce(n, ax1_fused_0 * 64 + ax1_fused_1)
                        T.where(ax1_fused_0 * 64 + ax1_fused_1 < n)
                        T.reads(lv1645[0, v0, 0, v1], T_softmax_maxelem_shared[0, v0, 0])
                        T.writes(T_softmax_expsum_shared[0, v0, 0])
                        with T.init():
                            T_softmax_expsum_shared[0, v0, 0] = T.float32(0)
                        T_softmax_expsum_shared[0, v0, 0] = T_softmax_expsum_shared[0, v0, 0] + T.exp(lv1645[0, v0, 0, v1] - T_softmax_maxelem_shared[0, v0, 0])
            for ax1_0 in range((n + 63) // 64):
                for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("compute"):
                        v0 = T.axis.spatial(32, ax0_fused)
                        v1 = T.axis.spatial(n, ax1_0 * 64 + ax1_1)
                        T.where(ax1_0 * 64 + ax1_1 < n)
                        T.reads(lv1645[0, v0, 0, v1], T_softmax_maxelem_shared[0, v0, 0], T_softmax_expsum_shared[0, v0, 0])
                        T.writes(var_compute_intermediate[0, v0, 0, v1])
                        var_compute_intermediate[0, v0, 0, v1] = T.Cast("float16", T.exp(lv1645[0, v0, 0, v1] - T_softmax_maxelem_shared[0, v0, 0]) / T_softmax_expsum_shared[0, v0, 0])

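    # fused_split2_silu1_multiply1: SwiGLU gating on the prefill path. The
    # fused gate/up projection (1, n, 22016) is split into two 11008-wide
    # halves; SiLU is applied to the gate half and multiplied elementwise
    # by the up half. Equivalent math (sketch):
    #     gate, up = lv3[..., :11008], lv3[..., 11008:]
    #     out = gate * sigmoid(gate) * up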
    @T.prim_func(private=True)
    def fused_split2_silu1_multiply1(p_lv3: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv3 = T.match_buffer(p_lv3, (1, n, 22016), "float16")
        var_T_multiply_intermediate = T.match_buffer(p_output0, (1, n, 11008), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(n * 43, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_multiply_1"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 11008)
                    v1 = T.axis.spatial(11008, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 11008)
                    T.reads(lv3[0, v0, v1:v1 + 11009])
                    T.writes(var_T_multiply_intermediate[0, v0, v1])
                    var_T_multiply_intermediate[0, v0, v1] = lv3[0, v0, v1] * T.sigmoid(lv3[0, v0, v1]) * lv3[0, v0, v1 + 11008]

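    # fused_split_silu_multiply: the same SwiGLU fusion specialized to a
    # single decode token; 43 blocks x 256 threads cover the 11008 outputs.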
    @T.prim_func(private=True)
    def fused_split_silu_multiply(lv164: T.Buffer((1, 1, 22016), "float16"), var_T_multiply_intermediate: T.Buffer((1, 1, 11008), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(43, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_multiply_1"):
                    v0 = T.axis.spatial(11008, ax0_fused_0 * 256 + ax0_fused_1)
                    T.reads(lv164[0, 0, v0:v0 + 11009])
                    T.writes(var_T_multiply_intermediate[0, 0, v0])
                    var_T_multiply_intermediate[0, 0, v0] = lv164[0, 0, v0] * T.sigmoid(lv164[0, 0, v0]) * lv164[0, 0, v0 + 11008]

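    # fused_transpose7_reshape4: folds the per-head attention output
    # (1, 32, 1, 128) back into the hidden layout (1, 1, 4096); 32 heads
    # x 128 channels = 4096, copied one element per thread.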
    @T.prim_func(private=True)
    def fused_transpose7_reshape4(lv1648: T.Buffer((1, 32, 1, 128), "float16"), var_T_reshape_intermediate: T.Buffer((1, 1, 4096), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
                    T.reads(lv1648[0, v0 // 128, 0, v0 % 128])
                    T.writes(var_T_reshape_intermediate[0, 0, v0])
                    var_T_reshape_intermediate[0, 0, v0] = lv1648[0, v0 // 128, 0, v0 % 128]

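    # matmul10: prefill-path batched GEMM (attention scores x values),
    # A (1, 32, n, m) x B (1, 32, m, 128) -> (1, 32, n, 128). blockIdx.y
    # fuses the head index with a 2-way split of the 128-wide output, so
    # each block owns a 32 x 64 tile; 16 x 8 threads each accumulate a
    # 4 x 4 register tile, with 16-wide k-slices of A and B staged through
    # shared memory via vectorized copies. Rows are padded to a multiple
    # of 32 and masked with `if v1 < n` on write-back.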
    @T.prim_func(private=True)
    def matmul10(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n, m = T.int32(), T.int32()
        A = T.match_buffer(var_A, (1, 32, n, m), "float16")
        B = T.match_buffer(var_B, (1, 32, m, 128), "float16")
        matmul = T.match_buffer(var_matmul, (1, 32, n, 128), "float16")
        # with T.block("root"):
        matmul_reindex_pad_local = T.alloc_buffer((32, (n + 31) // 32 * 32, 128), "float16", scope="local")
        A_reindex_pad_shared = T.alloc_buffer((32, (n + 31) // 32 * 32, (m + 15) // 16 * 16), "float16", scope="shared")
        B_reindex_pad_shared = T.alloc_buffer((32, 128, (m + 15) // 16 * 16), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_init in T.grid(4, 4):
                                    with T.block("matmul_init"):
                                        v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2)
                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                        v2 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
                                        T.reads()
                                        T.writes(matmul_reindex_pad_local[v0, v1, v2])
                                        matmul_reindex_pad_local[v0, v1, v2] = T.float16(0)
                                for ax3_0 in range((m + 15) // 16):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(2):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("A_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(A[0, v0, v1, v2])
                                                        T.writes(A_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n and v2 < m, A[0, v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("B_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2)
                                                        v1 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.reads(B[0, v0, v2, v1])
                                                        T.writes(B_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        B_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < m, B[0, v0, v2, v1], T.float16(0))
                                    for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
                                        with T.block("matmul_update"):
                                            v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                            v2 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
                                            v3 = T.axis.reduce((m + 15) // 16 * 16, ax3_0 * 16 + ax3_1)
                                            T.reads(matmul_reindex_pad_local[v0, v1, v2], A_reindex_pad_shared[v0, v1, v3], B_reindex_pad_shared[v0, v2, v3])
                                            T.writes(matmul_reindex_pad_local[v0, v1, v2])
                                            matmul_reindex_pad_local[v0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2] + A_reindex_pad_shared[v0, v1, v3] * B_reindex_pad_shared[v0, v2, v3]
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
                                    for ax2_1_1 in T.vectorized(2):
                                        with T.block("matmul_reindex_pad_local"):
                                            v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2 + ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
                                            T.reads(matmul_reindex_pad_local[v0, v1, v2])
                                            T.writes(matmul[0, v0, v1, v2])
                                            if v1 < n:
                                                matmul[0, v0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]

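    # matmul9: decode-path attention x values, (1, 32, 1, n) x
    # (1, 32, n, 128) -> (1, 32, 1, 128). The length-n reduction is
    # rfactor-split across 16 threadIdx.y lanes, each keeping a partial sum
    # in registers (matmul_rf_local); a final cross-thread reduction
    # combines the 16 partials. 256 blocks x 16 threadIdx.x lanes cover the
    # 32 x 128 outputs.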
    @T.prim_func(private=True)
    def matmul9(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((1, 32, 1, 128), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, 32, 1, n), "float16")
        B = T.match_buffer(var_B, (1, 32, n, 128), "float16")
        # with T.block("root"):
        matmul_rf_local = T.alloc_buffer((16, 1, 32, 1, 128), "float16", scope="local")
        for ax0_ax1_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
                for ax2_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                    with T.block("matmul_rf_init"):
                        vax2_fused_1 = T.axis.spatial(16, ax2_fused_1)
                        v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) // 128)
                        v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) % 128)
                        T.reads()
                        T.writes(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1])
                        matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] = T.float16(0)
                    for ax2_fused_0, u in T.grid((n + 15) // 16, 1):
                        with T.block("matmul_rf_update"):
                            vax2_fused_1 = T.axis.spatial(16, ax2_fused_1)
                            v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) // 128)
                            v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) % 128)
                            vax2_fused_0 = T.axis.reduce((n + 15) // 16, ax2_fused_0)
                            T.where(ax2_fused_0 * 16 + ax2_fused_1 < n)
                            T.reads(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1], A[0, v0, 0, vax2_fused_0 * 16 + vax2_fused_1], B[0, v0, vax2_fused_0 * 16 + vax2_fused_1, v1])
                            T.writes(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1])
                            matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] = matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] + A[0, v0, 0, vax2_fused_0 * 16 + vax2_fused_1] * B[0, v0, vax2_fused_0 * 16 + vax2_fused_1, v1]
            for ax1_ax2_fused in T.thread_binding(16, thread="threadIdx.x"):
                for ax0 in T.thread_binding(16, thread="threadIdx.y"):
                    with T.block("matmul"):
                        vax2_fused_1 = T.axis.reduce(16, ax0)
                        v0 = T.axis.spatial(32, ax0_ax1_fused_0 // 8)
                        v1 = T.axis.spatial(128, ax0_ax1_fused_0 % 8 * 16 + ax1_ax2_fused)
                        T.reads(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1])
                        T.writes(matmul[0, v0, 0, v1])
                        with T.init():
                            matmul[0, v0, 0, v1] = T.float16(0)
                        matmul[0, v0, 0, v1] = matmul[0, v0, 0, v1] + matmul_rf_local[vax2_fused_1, 0, v0, 0, v1]

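    # The reshape* kernels below are pure layout copies with no arithmetic:
    # token-id reshapes plus views between the (..., 4096) hidden layout
    # and the (..., 32, 128) per-head layout, one element per thread.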
    @T.prim_func(private=True)
    def reshape(A: T.Buffer((1, 1), "int32"), T_reshape: T.Buffer((1,), "int32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(1, 0)
                    T.where(ax0_fused_0 * 256 + ax0_fused_1 < 1)
                    T.reads(A[0, 0])
                    T.writes(T_reshape[0])
                    T_reshape[0] = A[0, 0]

    @T.prim_func(private=True)
    def reshape1(A: T.Buffer((1, 4096), "float16"), T_reshape: T.Buffer((1, 1, 4096), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
                    T.reads(A[0, v0])
                    T.writes(T_reshape[0, 0, v0])
                    T_reshape[0, 0, v0] = A[0, v0]

    @T.prim_func(private=True)
    def reshape3(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (n, 32, 128), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (1, n, 32, 128), "float16")
        # with T.block("root"):
        for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
                    v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
                    T.reads(A[v0, v1, v2])
                    T.writes(T_reshape[0, v0, v1, v2])
                    T_reshape[0, v0, v1, v2] = A[v0, v1, v2]

    @T.prim_func(private=True)
    def reshape5(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n), "int32")
        T_reshape = T.match_buffer(var_T_reshape, (n,), "int32")
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding((n + 255) // 256, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(n, ax0_fused_0 * 256 + ax0_fused_1)
                    T.where(ax0_fused_0 * 256 + ax0_fused_1 < n)
                    T.reads(A[0, v0])
                    T.writes(T_reshape[v0])
                    T_reshape[v0] = A[0, v0]

    @T.prim_func(private=True)
    def reshape6(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (n, 4096), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (1, n, 4096), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
                    v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
                    T.reads(A[v0, v1])
                    T.writes(T_reshape[0, v0, v1])
                    T_reshape[0, v0, v1] = A[v0, v1]

    @T.prim_func(private=True)
    def reshape7(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 4096), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (1, n, 32, 128), "float16")
        # with T.block("root"):
        for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
                    v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
                    T.reads(A[0, v0, v1 * 128 + v2])
                    T.writes(T_reshape[0, v0, v1, v2])
                    T_reshape[0, v0, v1, v2] = A[0, v0, v1 * 128 + v2]

    @T.prim_func(private=True)
    def reshape8(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 32, 128), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (1, n, 4096), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
                    v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
                    T.reads(A[0, v0, v1 // 128, v1 % 128])
                    T.writes(T_reshape[0, v0, v1])
                    T_reshape[0, v0, v1] = A[0, v0, v1 // 128, v1 % 128]

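    # rms_norm: y = w * x / sqrt(mean(x^2) + eps), accumulated in float32.
    # The constant 0.000244140625 is 1/4096 (the hidden size) and eps is
    # 1e-5. The squared-sum reduction is rfactor-split across 64 threads,
    # with register partials combined in shared memory; rms_norm1 below is
    # the same kernel specialized to sequence length 1. Sketch:
    #     ms = (x.astype(np.float32) ** 2).mean(-1, keepdims=True)
    #     y = (w * x / np.sqrt(ms + 1e-5)).astype(np.float16)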
    @T.prim_func(private=True)
    def rms_norm(var_A: T.handle, B: T.Buffer((4096,), "float16"), var_rms_norm: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 4096), "float16")
        rms_norm_1 = T.match_buffer(var_rms_norm, (1, n, 4096), "float16")
        # with T.block("root"):
        Ared_temp_shared = T.alloc_buffer((1, n), scope="shared")
        Ared_temp_rf_local = T.alloc_buffer((64, 1, n), scope="local")
        for ax0_fused in T.thread_binding(n, thread="blockIdx.x"):
            for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                with T.block("Ared_temp_rf_init"):
                    vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused])
                    T.reads()
                    T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0])
                    Ared_temp_rf_local[vax1_fused_1, 0, v0] = T.float32(0)
                for ax1_fused_0, u in T.grid(64, 1):
                    with T.block("Ared_temp_rf_update"):
                        vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0])
                        T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0], A[0, v0, vax1_fused_0 * 64 + vax1_fused_1])
                        T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0])
                        Ared_temp_rf_local[vax1_fused_1, 0, v0] = Ared_temp_rf_local[vax1_fused_1, 0, v0] + T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1])
            for ax1_fused in range(1):
                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("Ared_temp"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused])
                        T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0])
                        T.writes(Ared_temp_shared[0, v0])
                        with T.init():
                            Ared_temp_shared[0, v0] = T.float32(0)
                        Ared_temp_shared[0, v0] = Ared_temp_shared[0, v0] + Ared_temp_rf_local[vax1_fused_1, 0, v0]
            for ax0_fused_0 in range(64):
                for ax0_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("rms_norm"):
                        v0 = T.axis.spatial(n, ax0_fused)
                        v1 = T.axis.spatial(4096, ax0_fused_0 * 64 + ax0_fused_1)
                        T.reads(B[v1], A[0, v0, v1], Ared_temp_shared[0, v0])
                        T.writes(rms_norm_1[0, v0, v1])
                        rms_norm_1[0, v0, v1] = T.Cast("float16", T.Cast("float32", B[v1]) * (T.Cast("float32", A[0, v0, v1]) / T.sqrt(Ared_temp_shared[0, v0] * T.float32(0.000244140625) + T.float32(1.0000000000000001e-05))))

    @T.prim_func(private=True)
    def rms_norm1(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), "float16"), rms_norm: T.Buffer((1, 1, 4096), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        Ared_temp_shared = T.alloc_buffer((1, 1), scope="shared")
        Ared_temp_rf_local = T.alloc_buffer((64, 1, 1), scope="local")
        for ax0_fused in T.thread_binding(1, thread="blockIdx.x"):
            for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                with T.block("Ared_temp_rf_init"):
                    vax1_fused_1 = T.axis.spatial(64, ax1_fused_1)
                    v0 = T.axis.spatial(1, 0)
                    T.reads()
                    T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0])
                    Ared_temp_rf_local[vax1_fused_1, 0, 0] = T.float32(0)
                for ax1_fused_0, u in T.grid(64, 1):
                    with T.block("Ared_temp_rf_update"):
                        vax1_fused_1 = T.axis.spatial(64, ax1_fused_1)
                        v0 = T.axis.spatial(1, 0)
                        vax1_fused_0 = T.axis.reduce(64, ax1_fused_0)
                        T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0], A[0, 0, vax1_fused_0 * 64 + vax1_fused_1])
                        T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0])
                        Ared_temp_rf_local[vax1_fused_1, 0, 0] = Ared_temp_rf_local[vax1_fused_1, 0, 0] + T.Cast("float32", A[0, 0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", A[0, 0, vax1_fused_0 * 64 + vax1_fused_1])
            for ax1_fused in range(1):
                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("Ared_temp"):
                        vax1_fused_1 = T.axis.reduce(64, ax0)
                        v0 = T.axis.spatial(1, 0)
                        T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0])
                        T.writes(Ared_temp_shared[0, 0])
                        with T.init():
                            Ared_temp_shared[0, 0] = T.float32(0)
                        Ared_temp_shared[0, 0] = Ared_temp_shared[0, 0] + Ared_temp_rf_local[vax1_fused_1, 0, 0]
            for ax0_fused_0 in range(64):
                for ax0_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("rms_norm"):
                        v0 = T.axis.spatial(4096, ax0_fused_0 * 64 + ax0_fused_1)
                        T.reads(B[v0], A[0, 0, v0], Ared_temp_shared[0, 0])
                        T.writes(rms_norm[0, 0, v0])
                        rms_norm[0, 0, v0] = T.Cast("float16", T.Cast("float32", B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp_shared[0, 0] * T.float32(0.000244140625) + T.float32(1.0000000000000001e-05))))

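    # rotary_embedding: applies RoPE to (1, n, 32, 128) activations. B and
    # C are (2048, 128) lookup tables (cos and sin respectively, judging by
    # the matching `cos`/`sin` parameters of split_rotary below); the
    # absolute position is v0 + (m - n), i.e. offset by the already-cached
    # prefix length. Per lane:
    #     out = cos[pos] * x + sin[pos] * rotate_half(x)
    # where rotate_half pairs lane v2 with v2 -/+ 64:
    #     rotate_half(x)[v2] = x[v2 - 64] if v2 >= 64 else -x[v2 + 64]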
    @T.prim_func(private=True)
    def rotary_embedding(var_A: T.handle, B: T.Buffer((2048, 128), "float16"), C: T.Buffer((2048, 128), "float16"), var_rotary: T.handle, m: T.int32):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 32, 128), "float16")
        rotary = T.match_buffer(var_rotary, (1, n, 32, 128), "float16")
        # with T.block("root"):
        for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("rotary"):
                    v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
                    v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
                    T.reads(B[v0 + (m - n), v2], A[0, v0, v1, v2 + -64:v2 + -64 + 129], C[v0 + (m - n), v2])
                    T.writes(rotary[0, v0, v1, v2])
                    rotary[0, v0, v1, v2] = B[v0 + (m - n), v2] * A[0, v0, v1, v2] + C[v0 + (m - n), v2] * T.Select(64 <= v2, A[0, v0, v1, v2 + -64], A[0, v0, v1, v2 + 64] * T.float16(-1))

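    # slice: extracts the last position's hidden vector A[0, n - 1, :]
    # (presumably the only one the LM head needs after prefill); slice1
    # below is the trivial sequence-length-1 copy.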
    @T.prim_func(private=True)
    def slice(var_A: T.handle, slice_1: T.Buffer((1, 1, 4096), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 4096), "float16")
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("slice"):
                    v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
                    T.reads(A[0, n - 1, v0])
                    T.writes(slice_1[0, 0, v0])
                    slice_1[0, 0, v0] = A[0, n - 1, v0]

    @T.prim_func(private=True)
    def slice1(A: T.Buffer((1, 1, 4096), "float16"), slice: T.Buffer((1, 1, 4096), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("slice"):
                    v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
                    T.reads(A[0, 0, v0])
                    T.writes(slice[0, 0, v0])
                    slice[0, 0, v0] = A[0, 0, v0]

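    # softmax2: full-vocabulary softmax over the (1, 1, 32001) logits in
    # float32, run in a single thread block (501 x 64 = 32064 lanes cover
    # the 32001 entries); same max / exp-sum / normalize structure as the
    # attention softmaxes above.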
    @T.prim_func(private=True)
    def softmax2(A: T.Buffer((1, 1, 32001), "float32"), T_softmax_norm: T.Buffer((1, 1, 32001), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_softmax_maxelem_shared = T.alloc_buffer((1, 1), scope="shared")
        T_softmax_expsum_shared = T.alloc_buffer((1, 1), scope="shared")
        for ax0_fused in T.thread_binding(1, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for ax0, ax1_fused_0 in T.grid(1, 501):
                for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_maxelem"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.reduce(32001, ax1_fused_0 * 64 + ax1_fused_1)
                        T.where(ax1_fused_0 * 64 + ax1_fused_1 < 32001)
                        T.reads(A[0, 0, v1])
                        T.writes(T_softmax_maxelem_shared[0, 0])
                        with T.init():
                            T_softmax_maxelem_shared[0, 0] = T.float32(-3.4028234663852886e+38)
                        T_softmax_maxelem_shared[0, 0] = T.max(T_softmax_maxelem_shared[0, 0], A[0, 0, v1])
            for ax0, ax1_fused_0 in T.grid(1, 501):
                for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_expsum"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.reduce(32001, ax1_fused_0 * 64 + ax1_fused_1)
                        T.where(ax1_fused_0 * 64 + ax1_fused_1 < 32001)
                        T.reads(A[0, 0, v1], T_softmax_maxelem_shared[0, 0])
                        T.writes(T_softmax_expsum_shared[0, 0])
                        with T.init():
                            T_softmax_expsum_shared[0, 0] = T.float32(0)
                        T_softmax_expsum_shared[0, 0] = T_softmax_expsum_shared[0, 0] + T.exp(A[0, 0, v1] - T_softmax_maxelem_shared[0, 0])
            for ax1_0 in range(501):
                for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_norm"):
                        v0 = T.axis.spatial(1, 0)
                        v1 = T.axis.spatial(32001, ax1_0 * 64 + ax1_1)
                        T.where(ax1_0 * 64 + ax1_1 < 32001)
                        T.reads(A[0, 0, v1], T_softmax_maxelem_shared[0, 0], T_softmax_expsum_shared[0, 0])
                        T.writes(T_softmax_norm[0, 0, v1])
                        T.block_attr({"axis": 2})
                        T_softmax_norm[0, 0, v1] = T.exp(A[0, 0, v1] - T_softmax_maxelem_shared[0, 0]) / T_softmax_expsum_shared[0, 0]

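    # split1: splits the fused QKV projection (1, n, 12288 = 3 x 4096)
    # into three (1, n, 4096) slices — Q, K and V, judging by split_rotary
    # below, which applies RoPE to the first two.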
    @T.prim_func(private=True)
    def split1(var_A: T.handle, var_T_split: T.handle, var_T_split_1: T.handle, var_T_split_2: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 12288), "float16")
        T_split = T.match_buffer(var_T_split, (1, n, 4096), "float16")
        T_split_1 = T.match_buffer(var_T_split_1, (1, n, 4096), "float16")
        T_split_2 = T.match_buffer(var_T_split_2, (1, n, 4096), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_split"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
                    v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
                    T.reads(A[0, v0, v1])
                    T.writes(T_split[0, v0, v1])
                    T_split[0, v0, v1] = A[0, v0, v1]
        for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_split_1"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
                    v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
                    T.reads(A[0, v0, v1 + 4096])
                    T.writes(T_split_1[0, v0, v1])
                    T_split_1[0, v0, v1] = A[0, v0, v1 + 4096]
        for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_split_2"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
                    v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
                    T.reads(A[0, v0, v1 + 8192])
                    T.writes(T_split_2[0, v0, v1])
                    T_split_2[0, v0, v1] = A[0, v0, v1 + 8192]

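    # split_rotary: fused decode-path QKV split + RoPE. The three 4096-wide
    # slices of A are Q, K and V; RoPE at position n - 1 is applied to Q
    # and K in the same pass (cos/sin are the (2048, 128) tables) while V
    # is copied through unchanged. Unlike the other kernels, this one is
    # not marked private.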
    @T.prim_func
    def split_rotary(A: T.Buffer((1, 1, 12288), "float16"), cos: T.Buffer((2048, 128), "float16"), sin: T.Buffer((2048, 128), "float16"), T_split: T.Buffer((1, 1, 4096), "float16"), T_split_1: T.Buffer((1, 1, 4096), "float16"), T_split_2: T.Buffer((1, 1, 4096), "float16"), n: T.int32):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_split"):
                    v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
                    T.reads(A[0, 0, v0], A[0, 0, v0 + 4096], A[0, 0, v0 + 8192])
                    T.writes(T_split[0, 0, v0], T_split_1[0, 0, v0], T_split_2[0, 0, v0])
                    T_split[0, 0, v0] = cos[n - 1, v0 % 128] * A[0, 0, v0] + sin[n - 1, v0 % 128] * T.Select(64 <= v0 % 128, A[0, 0, v0 + -64], A[0, 0, v0 + 64] * T.float16(-1))
                    T_split_1[0, 0, v0] = cos[n - 1, v0 % 128] * A[0, 0, v0 + 4096] + sin[n - 1, v0 % 128] * T.Select(64 <= v0 % 128, A[0, 0, v0 + 4032], A[0, 0, v0 + 4160] * T.float16(-1))
                    T_split_2[0, 0, v0] = A[0, 0, v0 + 8192]

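    # squeeze1 drops the leading batch dimension, (1, n, 32, 128) ->
    # (n, 32, 128) — presumably the layout the KV-cache append expects;
    # transpose6/transpose8 further below swap between token-major
    # (1, n, 32, 128) and head-major (1, 32, n, 128) layouts around
    # attention.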
    @T.prim_func(private=True)
    def squeeze1(var_A: T.handle, var_T_squeeze: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 32, 128), "float16")
        T_squeeze = T.match_buffer(var_T_squeeze, (n, 32, 128), "float16")
        # with T.block("root"):
        for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_squeeze"):
                    v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
                    v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
                    T.reads(A[0, v0, v1, v2])
                    T.writes(T_squeeze[v0, v1, v2])
                    T_squeeze[v0, v1, v2] = A[0, v0, v1, v2]

    @T.prim_func(private=True)
    def transpose6(var_A: T.handle, var_T_transpose: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 32, 128), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (1, 32, n, 128), "float16")
        # with T.block("root"):
        for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_transpose"):
                    v0 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // (128 * n))
                    v1 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % (128 * n) // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
                    T.reads(A[0, v1, v0, v2])
                    T.writes(T_transpose[0, v0, v1, v2])
                    T_transpose[0, v0, v1, v2] = A[0, v1, v0, v2]

    @T.prim_func(private=True)
    def transpose8(var_A: T.handle, var_T_transpose: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, 32, n, 128), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (1, n, 32, 128), "float16")
        # with T.block("root"):
        for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("T_transpose"):
                    v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
                    v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
                    v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
                    T.reads(A[0, v1, v0, v2])
                    T.writes(T_transpose[0, v0, v1, v2])
                    T_transpose[0, v0, v1, v2] = A[0, v1, v0, v2]

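    # create_kv_cache: allocates 64 cache objects — presumably one key and
    # one value cache for each of 32 layers — each backed by a
    # (2048, 32, 128) float16 buffer (max context 2048, 32 heads, head dim
    # 128) with an initial fill length of 0.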
    @R.function
    def create_kv_cache() -> R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object):
        R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
        with R.dataflow():
            lv3221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3222: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3223: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3224: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3225: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3226: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3227: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3228: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3229: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3230: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3231: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3232: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3233: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3234: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3235: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3236: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3237: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3238: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3239: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3240: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3241: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3242: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3243: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3244: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3245: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3246: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3247: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3248: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3249: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3250: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3251: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3252: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3253: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3254: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3255: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3256: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3257: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3258: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3259: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3260: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3261: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3262: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3263: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3264: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3265: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3266: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3267: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3268: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3269: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3270: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3271: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3272: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3273: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3274: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3275: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3276: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3277: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3278: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3279: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3280: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3281: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3282: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3283: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3284: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            gv2: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object) = lv3221, lv3222, lv3223, lv3224, lv3225, lv3226, lv3227, lv3228, lv3229, lv3230, lv3231, lv3232, lv3233, lv3234, lv3235, lv3236, lv3237, lv3238, lv3239, lv3240, lv3241, lv3242, lv3243, lv3244, lv3245, lv3246, lv3247, lv3248, lv3249, lv3250, lv3251, lv3252, lv3253, lv3254, lv3255, lv3256, lv3257, lv3258, lv3259, lv3260, lv3261, lv3262, lv3263, lv3264, lv3265, lv3266, lv3267, lv3268, lv3269, lv3270, lv3271, lv3272, lv3273, lv3274, lv3275, lv3276, lv3277, lv3278, lv3279, lv3280, lv3281, lv3282, lv3283, lv3284
            R.output(gv2)
        return gv2

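    # decode: single-token forward entry point. It takes the current token
    # id (1, 1), the running sequence length n, the 64 KV-cache handles,
    # and the packed quantized weights. Each uint32/float16 pair (e.g.
    # (12288, 512) uint32 alongside (12288, 128) float16) is consistent
    # with 4-bit group quantization: 8 weights packed per uint32 word and
    # one float16 scale per group of 32.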
    @R.function
    def decode(input_ids1: R.Tensor((1, 1), dtype="int32"), all_seq_len: R.Shape(["n"]), kv_cache: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object), params: R.Tuple(R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), 
dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), 
R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), 
dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, 
R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)):
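        # Annotation (inferred from shapes, not asserted by the dump): this Relax
        # function appears to be the single-token decode step of a 32-layer
        # Llama-architecture model as emitted by an MLC-LLM-style pipeline, with
        # group-quantized weights (packed uint32 data plus float16 scales,
        # consistent with a q4f16-style scheme). `params` carries roughly ten
        # tensors per layer (QKV, o_proj, gate/up, down projections and two
        # (4096,) RMS-norm weights), followed by the final norm, the quantized
        # lm_head over a 32001-token vocabulary, and two (2048, 128) rotary
        # tables. The return value is the (1, 1, 32001) float32 logits plus 64
        # opaque KV-cache handles, i.e. 32 layers x {K, V}.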
        n = T.int64()
        R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
        cls = Module
        with R.dataflow():
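            # Token embedding: reshape the input ids to (1,), gather the matching
            # row from the packed quantized embedding table (params[0] data,
            # params[1] scales) via a fused dequantize+take, and reshape the
            # result to a (1, 1, 4096) hidden state.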
            lv1611 = R.call_tir(cls.reshape, (input_ids1,), out_sinfo=R.Tensor((1,), dtype="int32"))
            lv: R.Tensor((32001, 512), dtype="uint32") = params[0]
            lv1: R.Tensor((32001, 128), dtype="float16") = params[1]
            lv1_1 = R.call_tir(cls.fused_fused_decode1_take, (lv, lv1, lv1611), out_sinfo=R.Tensor((1, 4096), dtype="float16"))
            lv1613 = R.call_tir(cls.reshape1, (lv1_1,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
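            # Decode-step attention "mask": a (1, 1, 1, n) buffer filled with
            # 65504 (float16 max) by cls.full. Reading this as a no-op clamp is
            # an assumption: with a single query token attending to all n cached
            # positions, min(scores, +65504) below should leave scores unchanged.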
            lv1614 = R.call_tir(cls.full, R.tuple(), out_sinfo=R.Tensor((1, 1, 1, n), dtype="float16"))
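            # Decoder layer 0, attention half: input RMS norm (weight params[10]),
            # a fused dequantize + NT matmul producing the packed (1, 1, 12288)
            # QKV projection, then split_rotary to slice it into Q/K/V with rotary
            # position embedding applied; params[325]/params[326] are presumably
            # the (2048, 128) cos/sin caches, and n is the current sequence length.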
            lv460: R.Tensor((4096,), dtype="float16") = params[10]
            lv1615 = R.call_tir(cls.rms_norm1, (lv1613, lv460), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv3: R.Tensor((12288, 512), dtype="uint32") = params[2]
            lv4: R.Tensor((12288, 128), dtype="float16") = params[3]
            lv_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv3, lv4, lv1615), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv464: R.Tensor((2048, 128), dtype="float16") = params[325]
            lv465: R.Tensor((2048, 128), dtype="float16") = params[326]
            lv_2 = R.call_tir(cls.split_rotary, (lv_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
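            # Reshape Q to (1, 32, 1, 128) heads; squeeze K and V to (1, 32, 128)
            # and append them to this layer's caches (kv_cache[0]/kv_cache[1]).
            # The cache views then expose all n cached positions, which are
            # reshaped and transposed into the (1, 32, n, 128) layout the
            # attention kernels expect.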
            lv6: R.Tensor((1, 1, 4096), dtype="float16") = lv_2[0]
            lv7 = R.call_tir(cls.fused_reshape2_transpose5, (lv6,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv8: R.Tensor((1, 1, 4096), dtype="float16") = lv_2[1]
            lv9 = R.call_tir(cls.fused_reshape2_squeeze, (lv8,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv10: R.Tensor((1, 1, 4096), dtype="float16") = lv_2[2]
            lv11 = R.call_tir(cls.fused_reshape2_squeeze, (lv10,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv1629: R.Object = kv_cache[0]
            lv1630: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1629, lv9, sinfo_args=(R.Object,))
            lv1631: R.Object = kv_cache[1]
            lv1632: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1631, lv11, sinfo_args=(R.Object,))
            lv1633: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1630, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1634: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1632, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1635 = R.call_tir(cls.reshape3, (lv1633,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1636 = R.call_tir(cls.reshape3, (lv1634,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1638 = R.call_tir(cls.transpose6, (lv1635,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1639 = R.call_tir(cls.transpose6, (lv1636,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
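            # Scaled dot-product attention: fused Q @ K^T with divide/max/min and
            # a cast to float32 for a numerically safer softmax (the divide is
            # presumably the 1/sqrt(128) ~= 0.0884 head-dim scaling), softmax cast
            # back to float16, then scores @ V and a transpose/reshape of the
            # (1, 32, 1, 128) result to (1, 1, 4096).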
            lv12 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv7, lv1638, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv13 = R.call_tir(cls.fused_softmax_cast1, (lv12,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1648 = R.call_tir(cls.matmul9, (lv13, lv1639), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv14 = R.call_tir(cls.fused_transpose7_reshape4, (lv1648,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
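            # Attention output projection (params[4]/params[5]) fused with the
            # residual add of the pre-norm hidden state lv1613.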
            lv15: R.Tensor((4096, 512), dtype="uint32") = params[4]
            lv16: R.Tensor((4096, 128), dtype="float16") = params[5]
            lv_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv15, lv16, lv14, lv1613), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
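            # MLP half: post-attention RMS norm (params[11]), fused gate+up
            # projection to (1, 1, 22016), SiLU(gate) * up reducing to
            # (1, 1, 11008), then the quantized down projection fused with the
            # residual add onto lv_3.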
            lv469: R.Tensor((4096,), dtype="float16") = params[11]
            lv1654 = R.call_tir(cls.rms_norm1, (lv_3, lv469), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv19: R.Tensor((22016, 512), dtype="uint32") = params[6]
            lv20: R.Tensor((22016, 128), dtype="float16") = params[7]
            lv1_2 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv19, lv20, lv1654), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv22 = R.call_tir(cls.fused_split_silu_multiply, (lv1_2,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv23: R.Tensor((4096, 1376), dtype="uint32") = params[8]
            lv24: R.Tensor((4096, 344), dtype="float16") = params[9]
            lv1_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv23, lv24, lv22, lv_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
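            # ---- decoder layer 1 ----
            # The layers below repeat the layer-0 pattern verbatim: for layer k,
            # the quantized weights are params[10*k + 2 .. 10*k + 9], the two
            # RMS-norm weights are params[10*k + 10] and params[10*k + 11], and
            # the K/V caches are kv_cache[2*k] and kv_cache[2*k + 1].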
            lv476: R.Tensor((4096,), dtype="float16") = params[20]
            lv1665 = R.call_tir(cls.rms_norm1, (lv1_3, lv476), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv27: R.Tensor((12288, 512), dtype="uint32") = params[12]
            lv28: R.Tensor((12288, 128), dtype="float16") = params[13]
            lv2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv27, lv28, lv1665), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv4_1 = R.call_tir(cls.split_rotary, (lv2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv30: R.Tensor((1, 1, 4096), dtype="float16") = lv4_1[0]
            lv31 = R.call_tir(cls.fused_reshape2_transpose5, (lv30,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv32: R.Tensor((1, 1, 4096), dtype="float16") = lv4_1[1]
            lv33 = R.call_tir(cls.fused_reshape2_squeeze, (lv32,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv34: R.Tensor((1, 1, 4096), dtype="float16") = lv4_1[2]
            lv35 = R.call_tir(cls.fused_reshape2_squeeze, (lv34,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv1679: R.Object = kv_cache[2]
            lv1680: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1679, lv33, sinfo_args=(R.Object,))
            lv1681: R.Object = kv_cache[3]
            lv1682: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1681, lv35, sinfo_args=(R.Object,))
            lv1683: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1680, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1684: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1682, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1685 = R.call_tir(cls.reshape3, (lv1683,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1686 = R.call_tir(cls.reshape3, (lv1684,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1688 = R.call_tir(cls.transpose6, (lv1685,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1689 = R.call_tir(cls.transpose6, (lv1686,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv36 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv31, lv1688, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv37 = R.call_tir(cls.fused_softmax_cast1, (lv36,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1698 = R.call_tir(cls.matmul9, (lv37, lv1689), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv38 = R.call_tir(cls.fused_transpose7_reshape4, (lv1698,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv39: R.Tensor((4096, 512), dtype="uint32") = params[14]
            lv40: R.Tensor((4096, 128), dtype="float16") = params[15]
            lv2_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv39, lv40, lv38, lv1_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv483: R.Tensor((4096,), dtype="float16") = params[21]
            lv1704 = R.call_tir(cls.rms_norm1, (lv2_1, lv483), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv43: R.Tensor((22016, 512), dtype="uint32") = params[16]
            lv44: R.Tensor((22016, 128), dtype="float16") = params[17]
            lv3_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv43, lv44, lv1704), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv46 = R.call_tir(cls.fused_split_silu_multiply, (lv3_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv47: R.Tensor((4096, 1376), dtype="uint32") = params[18]
            lv48: R.Tensor((4096, 344), dtype="float16") = params[19]
            lv3_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv47, lv48, lv46, lv2_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
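            # ---- decoder layer 2 (weights params[22..29], norms params[30]/params[31], kv_cache[4]/kv_cache[5]) ----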
            lv490: R.Tensor((4096,), dtype="float16") = params[30]
            lv1715 = R.call_tir(cls.rms_norm1, (lv3_2, lv490), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv51: R.Tensor((12288, 512), dtype="uint32") = params[22]
            lv52: R.Tensor((12288, 128), dtype="float16") = params[23]
            lv4_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv51, lv52, lv1715), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv8_1 = R.call_tir(cls.split_rotary, (lv4_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv54: R.Tensor((1, 1, 4096), dtype="float16") = lv8_1[0]
            lv55 = R.call_tir(cls.fused_reshape2_transpose5, (lv54,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv56: R.Tensor((1, 1, 4096), dtype="float16") = lv8_1[1]
            lv57 = R.call_tir(cls.fused_reshape2_squeeze, (lv56,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv58: R.Tensor((1, 1, 4096), dtype="float16") = lv8_1[2]
            lv59 = R.call_tir(cls.fused_reshape2_squeeze, (lv58,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv1729: R.Object = kv_cache[4]
            lv1730: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1729, lv57, sinfo_args=(R.Object,))
            lv1731: R.Object = kv_cache[5]
            lv1732: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1731, lv59, sinfo_args=(R.Object,))
            lv1733: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1730, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1734: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1732, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1735 = R.call_tir(cls.reshape3, (lv1733,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1736 = R.call_tir(cls.reshape3, (lv1734,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1738 = R.call_tir(cls.transpose6, (lv1735,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1739 = R.call_tir(cls.transpose6, (lv1736,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv60 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv55, lv1738, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv61 = R.call_tir(cls.fused_softmax_cast1, (lv60,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1748 = R.call_tir(cls.matmul9, (lv61, lv1739), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv62 = R.call_tir(cls.fused_transpose7_reshape4, (lv1748,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv63: R.Tensor((4096, 512), dtype="uint32") = params[24]
            lv64: R.Tensor((4096, 128), dtype="float16") = params[25]
            lv4_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv63, lv64, lv62, lv3_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv497: R.Tensor((4096,), dtype="float16") = params[31]
            lv1754 = R.call_tir(cls.rms_norm1, (lv4_3, lv497), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv67: R.Tensor((22016, 512), dtype="uint32") = params[26]
            lv68: R.Tensor((22016, 128), dtype="float16") = params[27]
            lv5 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv67, lv68, lv1754), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv70 = R.call_tir(cls.fused_split_silu_multiply, (lv5,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv71: R.Tensor((4096, 1376), dtype="uint32") = params[28]
            lv72: R.Tensor((4096, 344), dtype="float16") = params[29]
            lv5_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv71, lv72, lv70, lv4_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
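            # ---- decoder layer 3 (weights params[32..39], norms params[40]/params[41], kv_cache[6]/kv_cache[7]) ----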
            lv504: R.Tensor((4096,), dtype="float16") = params[40]
            lv1765 = R.call_tir(cls.rms_norm1, (lv5_1, lv504), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv75: R.Tensor((12288, 512), dtype="uint32") = params[32]
            lv76: R.Tensor((12288, 128), dtype="float16") = params[33]
            lv6_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv75, lv76, lv1765), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv12_1 = R.call_tir(cls.split_rotary, (lv6_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv78: R.Tensor((1, 1, 4096), dtype="float16") = lv12_1[0]
            lv79 = R.call_tir(cls.fused_reshape2_transpose5, (lv78,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv80: R.Tensor((1, 1, 4096), dtype="float16") = lv12_1[1]
            lv81 = R.call_tir(cls.fused_reshape2_squeeze, (lv80,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv82: R.Tensor((1, 1, 4096), dtype="float16") = lv12_1[2]
            lv83 = R.call_tir(cls.fused_reshape2_squeeze, (lv82,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv1779: R.Object = kv_cache[6]
            lv1780: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1779, lv81, sinfo_args=(R.Object,))
            lv1781: R.Object = kv_cache[7]
            lv1782: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1781, lv83, sinfo_args=(R.Object,))
            lv1783: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1780, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1784: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1782, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1785 = R.call_tir(cls.reshape3, (lv1783,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1786 = R.call_tir(cls.reshape3, (lv1784,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1788 = R.call_tir(cls.transpose6, (lv1785,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1789 = R.call_tir(cls.transpose6, (lv1786,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv84 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv79, lv1788, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv85 = R.call_tir(cls.fused_softmax_cast1, (lv84,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1798 = R.call_tir(cls.matmul9, (lv85, lv1789), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv86 = R.call_tir(cls.fused_transpose7_reshape4, (lv1798,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv87: R.Tensor((4096, 512), dtype="uint32") = params[34]
            lv88: R.Tensor((4096, 128), dtype="float16") = params[35]
            lv6_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv87, lv88, lv86, lv5_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv511: R.Tensor((4096,), dtype="float16") = params[41]
            lv1804 = R.call_tir(cls.rms_norm1, (lv6_2, lv511), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv91: R.Tensor((22016, 512), dtype="uint32") = params[36]
            lv92: R.Tensor((22016, 128), dtype="float16") = params[37]
            lv7_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv91, lv92, lv1804), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv94 = R.call_tir(cls.fused_split_silu_multiply, (lv7_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv95: R.Tensor((4096, 1376), dtype="uint32") = params[38]
            lv96: R.Tensor((4096, 344), dtype="float16") = params[39]
            lv7_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv95, lv96, lv94, lv6_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
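            # ---- decoder layer 4 (weights params[42..49], norms params[50]/params[51], kv_cache[8]/kv_cache[9]) ----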
            lv518: R.Tensor((4096,), dtype="float16") = params[50]
            lv1815 = R.call_tir(cls.rms_norm1, (lv7_2, lv518), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv99: R.Tensor((12288, 512), dtype="uint32") = params[42]
            lv100: R.Tensor((12288, 128), dtype="float16") = params[43]
            lv8_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv99, lv100, lv1815), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv16_1 = R.call_tir(cls.split_rotary, (lv8_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv102: R.Tensor((1, 1, 4096), dtype="float16") = lv16_1[0]
            lv103 = R.call_tir(cls.fused_reshape2_transpose5, (lv102,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv104: R.Tensor((1, 1, 4096), dtype="float16") = lv16_1[1]
            lv105 = R.call_tir(cls.fused_reshape2_squeeze, (lv104,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv106: R.Tensor((1, 1, 4096), dtype="float16") = lv16_1[2]
            lv107 = R.call_tir(cls.fused_reshape2_squeeze, (lv106,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv1829: R.Object = kv_cache[8]
            lv1830: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1829, lv105, sinfo_args=(R.Object,))
            lv1831: R.Object = kv_cache[9]
            lv1832: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1831, lv107, sinfo_args=(R.Object,))
            lv1833: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1830, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1834: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1832, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1835 = R.call_tir(cls.reshape3, (lv1833,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1836 = R.call_tir(cls.reshape3, (lv1834,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1838 = R.call_tir(cls.transpose6, (lv1835,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1839 = R.call_tir(cls.transpose6, (lv1836,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv108 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv103, lv1838, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv109 = R.call_tir(cls.fused_softmax_cast1, (lv108,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1848 = R.call_tir(cls.matmul9, (lv109, lv1839), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv110 = R.call_tir(cls.fused_transpose7_reshape4, (lv1848,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv111: R.Tensor((4096, 512), dtype="uint32") = params[44]
            lv112: R.Tensor((4096, 128), dtype="float16") = params[45]
            lv8_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv111, lv112, lv110, lv7_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv525: R.Tensor((4096,), dtype="float16") = params[51]
            lv1854 = R.call_tir(cls.rms_norm1, (lv8_3, lv525), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv115: R.Tensor((22016, 512), dtype="uint32") = params[46]
            lv116: R.Tensor((22016, 128), dtype="float16") = params[47]
            lv9_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv115, lv116, lv1854), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv118 = R.call_tir(cls.fused_split_silu_multiply, (lv9_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv119: R.Tensor((4096, 1376), dtype="uint32") = params[48]
            lv120: R.Tensor((4096, 344), dtype="float16") = params[49]
            lv9_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv119, lv120, lv118, lv8_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
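            # ---- decoder layer 5 (weights params[52..59], norms params[60]/params[61], kv_cache[10]/kv_cache[11]) ----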
            lv532: R.Tensor((4096,), dtype="float16") = params[60]
            lv1865 = R.call_tir(cls.rms_norm1, (lv9_2, lv532), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv123: R.Tensor((12288, 512), dtype="uint32") = params[52]
            lv124: R.Tensor((12288, 128), dtype="float16") = params[53]
            lv10_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv123, lv124, lv1865), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv20_1 = R.call_tir(cls.split_rotary, (lv10_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv126: R.Tensor((1, 1, 4096), dtype="float16") = lv20_1[0]
            lv127 = R.call_tir(cls.fused_reshape2_transpose5, (lv126,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv128: R.Tensor((1, 1, 4096), dtype="float16") = lv20_1[1]
            lv129 = R.call_tir(cls.fused_reshape2_squeeze, (lv128,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv130: R.Tensor((1, 1, 4096), dtype="float16") = lv20_1[2]
            lv131 = R.call_tir(cls.fused_reshape2_squeeze, (lv130,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv1879: R.Object = kv_cache[10]
            lv1880: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1879, lv129, sinfo_args=(R.Object,))
            lv1881: R.Object = kv_cache[11]
            lv1882: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1881, lv131, sinfo_args=(R.Object,))
            lv1883: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1880, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1884: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1882, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1885 = R.call_tir(cls.reshape3, (lv1883,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1886 = R.call_tir(cls.reshape3, (lv1884,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1888 = R.call_tir(cls.transpose6, (lv1885,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1889 = R.call_tir(cls.transpose6, (lv1886,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv132 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv127, lv1888, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv133 = R.call_tir(cls.fused_softmax_cast1, (lv132,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1898 = R.call_tir(cls.matmul9, (lv133, lv1889), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv134 = R.call_tir(cls.fused_transpose7_reshape4, (lv1898,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv135: R.Tensor((4096, 512), dtype="uint32") = params[54]
            lv136: R.Tensor((4096, 128), dtype="float16") = params[55]
            lv10_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv135, lv136, lv134, lv9_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv539: R.Tensor((4096,), dtype="float16") = params[61]
            lv1904 = R.call_tir(cls.rms_norm1, (lv10_2, lv539), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv139: R.Tensor((22016, 512), dtype="uint32") = params[56]
            lv140: R.Tensor((22016, 128), dtype="float16") = params[57]
            lv11_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv139, lv140, lv1904), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv142 = R.call_tir(cls.fused_split_silu_multiply, (lv11_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv143: R.Tensor((4096, 1376), dtype="uint32") = params[58]
            lv144: R.Tensor((4096, 344), dtype="float16") = params[59]
            lv11_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv143, lv144, lv142, lv10_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
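            # ---- decoder layer 6 (weights params[62..69], norms params[70]/params[71], kv_cache[12]/kv_cache[13]) ----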
            lv546: R.Tensor((4096,), dtype="float16") = params[70]
            lv1915 = R.call_tir(cls.rms_norm1, (lv11_2, lv546), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv147: R.Tensor((12288, 512), dtype="uint32") = params[62]
            lv148: R.Tensor((12288, 128), dtype="float16") = params[63]
            lv12_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv147, lv148, lv1915), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv24_1 = R.call_tir(cls.split_rotary, (lv12_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv150: R.Tensor((1, 1, 4096), dtype="float16") = lv24_1[0]
            lv151 = R.call_tir(cls.fused_reshape2_transpose5, (lv150,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv152: R.Tensor((1, 1, 4096), dtype="float16") = lv24_1[1]
            lv153 = R.call_tir(cls.fused_reshape2_squeeze, (lv152,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv154: R.Tensor((1, 1, 4096), dtype="float16") = lv24_1[2]
            lv155 = R.call_tir(cls.fused_reshape2_squeeze, (lv154,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv1929: R.Object = kv_cache[12]
            lv1930: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1929, lv153, sinfo_args=(R.Object,))
            lv1931: R.Object = kv_cache[13]
            lv1932: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1931, lv155, sinfo_args=(R.Object,))
            lv1933: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1930, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1934: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1932, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1935 = R.call_tir(cls.reshape3, (lv1933,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1936 = R.call_tir(cls.reshape3, (lv1934,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1938 = R.call_tir(cls.transpose6, (lv1935,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1939 = R.call_tir(cls.transpose6, (lv1936,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv156 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv151, lv1938, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv157 = R.call_tir(cls.fused_softmax_cast1, (lv156,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1948 = R.call_tir(cls.matmul9, (lv157, lv1939), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv158 = R.call_tir(cls.fused_transpose7_reshape4, (lv1948,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv159: R.Tensor((4096, 512), dtype="uint32") = params[64]
            lv160: R.Tensor((4096, 128), dtype="float16") = params[65]
            lv12_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv159, lv160, lv158, lv11_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv553: R.Tensor((4096,), dtype="float16") = params[71]
            lv1954 = R.call_tir(cls.rms_norm1, (lv12_3, lv553), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv163: R.Tensor((22016, 512), dtype="uint32") = params[66]
            lv164: R.Tensor((22016, 128), dtype="float16") = params[67]
            lv13_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv163, lv164, lv1954), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv166 = R.call_tir(cls.fused_split_silu_multiply, (lv13_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv167: R.Tensor((4096, 1376), dtype="uint32") = params[68]
            lv168: R.Tensor((4096, 344), dtype="float16") = params[69]
            lv13_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv167, lv168, lv166, lv12_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
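            # ---- decoder layer 7 (weights params[72..79], norms params[80]/params[81], kv_cache[14]/kv_cache[15]) ----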
            lv560: R.Tensor((4096,), dtype="float16") = params[80]
            lv1965 = R.call_tir(cls.rms_norm1, (lv13_2, lv560), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv171: R.Tensor((12288, 512), dtype="uint32") = params[72]
            lv172: R.Tensor((12288, 128), dtype="float16") = params[73]
            lv14_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv171, lv172, lv1965), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv28_1 = R.call_tir(cls.split_rotary, (lv14_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv174: R.Tensor((1, 1, 4096), dtype="float16") = lv28_1[0]
            lv175 = R.call_tir(cls.fused_reshape2_transpose5, (lv174,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv176: R.Tensor((1, 1, 4096), dtype="float16") = lv28_1[1]
            lv177 = R.call_tir(cls.fused_reshape2_squeeze, (lv176,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv178: R.Tensor((1, 1, 4096), dtype="float16") = lv28_1[2]
            lv179 = R.call_tir(cls.fused_reshape2_squeeze, (lv178,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv1979: R.Object = kv_cache[14]
            lv1980: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1979, lv177, sinfo_args=(R.Object,))
            lv1981: R.Object = kv_cache[15]
            lv1982: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1981, lv179, sinfo_args=(R.Object,))
            lv1983: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1980, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1984: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1982, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv1985 = R.call_tir(cls.reshape3, (lv1983,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1986 = R.call_tir(cls.reshape3, (lv1984,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1988 = R.call_tir(cls.transpose6, (lv1985,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1989 = R.call_tir(cls.transpose6, (lv1986,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv180 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv175, lv1988, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv181 = R.call_tir(cls.fused_softmax_cast1, (lv180,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1998 = R.call_tir(cls.matmul9, (lv181, lv1989), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv182 = R.call_tir(cls.fused_transpose7_reshape4, (lv1998,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv183: R.Tensor((4096, 512), dtype="uint32") = params[74]
            lv184: R.Tensor((4096, 128), dtype="float16") = params[75]
            lv14_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv183, lv184, lv182, lv13_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv567: R.Tensor((4096,), dtype="float16") = params[81]
            lv2004 = R.call_tir(cls.rms_norm1, (lv14_2, lv567), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv187: R.Tensor((22016, 512), dtype="uint32") = params[76]
            lv188: R.Tensor((22016, 128), dtype="float16") = params[77]
            lv15_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv187, lv188, lv2004), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv190 = R.call_tir(cls.fused_split_silu_multiply, (lv15_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv191: R.Tensor((4096, 1376), dtype="uint32") = params[78]
            lv192: R.Tensor((4096, 344), dtype="float16") = params[79]
            lv15_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv191, lv192, lv190, lv14_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
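            # ---- decoder layer 8 (weights params[82..89], norms params[90]/params[91], kv_cache[16]/kv_cache[17]) ----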
            lv574: R.Tensor((4096,), dtype="float16") = params[90]
            lv2015 = R.call_tir(cls.rms_norm1, (lv15_2, lv574), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv195: R.Tensor((12288, 512), dtype="uint32") = params[82]
            lv196: R.Tensor((12288, 128), dtype="float16") = params[83]
            lv16_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv195, lv196, lv2015), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv32_1 = R.call_tir(cls.split_rotary, (lv16_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv198: R.Tensor((1, 1, 4096), dtype="float16") = lv32_1[0]
            lv199 = R.call_tir(cls.fused_reshape2_transpose5, (lv198,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv200: R.Tensor((1, 1, 4096), dtype="float16") = lv32_1[1]
            lv201 = R.call_tir(cls.fused_reshape2_squeeze, (lv200,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv202: R.Tensor((1, 1, 4096), dtype="float16") = lv32_1[2]
            lv203 = R.call_tir(cls.fused_reshape2_squeeze, (lv202,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2029: R.Object = kv_cache[16]
            lv2030: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2029, lv201, sinfo_args=(R.Object,))
            lv2031: R.Object = kv_cache[17]
            lv2032: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2031, lv203, sinfo_args=(R.Object,))
            lv2033: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2030, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2034: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2032, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2035 = R.call_tir(cls.reshape3, (lv2033,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2036 = R.call_tir(cls.reshape3, (lv2034,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2038 = R.call_tir(cls.transpose6, (lv2035,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2039 = R.call_tir(cls.transpose6, (lv2036,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv204 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv199, lv2038, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv205 = R.call_tir(cls.fused_softmax_cast1, (lv204,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2048 = R.call_tir(cls.matmul9, (lv205, lv2039), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv206 = R.call_tir(cls.fused_transpose7_reshape4, (lv2048,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv207: R.Tensor((4096, 512), dtype="uint32") = params[84]
            lv208: R.Tensor((4096, 128), dtype="float16") = params[85]
            lv16_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv207, lv208, lv206, lv15_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv581: R.Tensor((4096,), dtype="float16") = params[91]
            lv2054 = R.call_tir(cls.rms_norm1, (lv16_3, lv581), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv211: R.Tensor((22016, 512), dtype="uint32") = params[86]
            lv212: R.Tensor((22016, 128), dtype="float16") = params[87]
            lv17 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv211, lv212, lv2054), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv214 = R.call_tir(cls.fused_split_silu_multiply, (lv17,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv215: R.Tensor((4096, 1376), dtype="uint32") = params[88]
            lv216: R.Tensor((4096, 344), dtype="float16") = params[89]
            lv17_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv215, lv216, lv214, lv16_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
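            # ---- decoder layer 9 (weights params[92..99], norms params[100]/params[101], kv_cache[18]/kv_cache[19]) ----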
            lv588: R.Tensor((4096,), dtype="float16") = params[100]
            lv2065 = R.call_tir(cls.rms_norm1, (lv17_1, lv588), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv219: R.Tensor((12288, 512), dtype="uint32") = params[92]
            lv220: R.Tensor((12288, 128), dtype="float16") = params[93]
            lv18 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv219, lv220, lv2065), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv36_1 = R.call_tir(cls.split_rotary, (lv18, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv222: R.Tensor((1, 1, 4096), dtype="float16") = lv36_1[0]
            lv223 = R.call_tir(cls.fused_reshape2_transpose5, (lv222,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv224: R.Tensor((1, 1, 4096), dtype="float16") = lv36_1[1]
            lv225 = R.call_tir(cls.fused_reshape2_squeeze, (lv224,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv226: R.Tensor((1, 1, 4096), dtype="float16") = lv36_1[2]
            lv227 = R.call_tir(cls.fused_reshape2_squeeze, (lv226,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2079: R.Object = kv_cache[18]
            lv2080: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2079, lv225, sinfo_args=(R.Object,))
            lv2081: R.Object = kv_cache[19]
            lv2082: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2081, lv227, sinfo_args=(R.Object,))
            lv2083: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2080, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2084: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2082, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2085 = R.call_tir(cls.reshape3, (lv2083,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2086 = R.call_tir(cls.reshape3, (lv2084,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2088 = R.call_tir(cls.transpose6, (lv2085,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2089 = R.call_tir(cls.transpose6, (lv2086,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv228 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv223, lv2088, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv229 = R.call_tir(cls.fused_softmax_cast1, (lv228,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2098 = R.call_tir(cls.matmul9, (lv229, lv2089), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv230 = R.call_tir(cls.fused_transpose7_reshape4, (lv2098,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv231: R.Tensor((4096, 512), dtype="uint32") = params[94]
            lv232: R.Tensor((4096, 128), dtype="float16") = params[95]
            lv18_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv231, lv232, lv230, lv17_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv595: R.Tensor((4096,), dtype="float16") = params[101]
            lv2104 = R.call_tir(cls.rms_norm1, (lv18_1, lv595), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv235: R.Tensor((22016, 512), dtype="uint32") = params[96]
            lv236: R.Tensor((22016, 128), dtype="float16") = params[97]
            lv19_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv235, lv236, lv2104), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv238 = R.call_tir(cls.fused_split_silu_multiply, (lv19_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv239: R.Tensor((4096, 1376), dtype="uint32") = params[98]
            lv240: R.Tensor((4096, 344), dtype="float16") = params[99]
            lv19_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv239, lv240, lv238, lv18_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
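            # ---- decoder layer 10 (weights params[102..109], norms params[110]/params[111], kv_cache[20]/kv_cache[21]) ----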
            lv602: R.Tensor((4096,), dtype="float16") = params[110]
            lv2115 = R.call_tir(cls.rms_norm1, (lv19_2, lv602), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv243: R.Tensor((12288, 512), dtype="uint32") = params[102]
            lv244: R.Tensor((12288, 128), dtype="float16") = params[103]
            lv20_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv243, lv244, lv2115), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv40_1 = R.call_tir(cls.split_rotary, (lv20_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv246: R.Tensor((1, 1, 4096), dtype="float16") = lv40_1[0]
            lv247 = R.call_tir(cls.fused_reshape2_transpose5, (lv246,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv248: R.Tensor((1, 1, 4096), dtype="float16") = lv40_1[1]
            lv249 = R.call_tir(cls.fused_reshape2_squeeze, (lv248,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv250: R.Tensor((1, 1, 4096), dtype="float16") = lv40_1[2]
            lv251 = R.call_tir(cls.fused_reshape2_squeeze, (lv250,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2129: R.Object = kv_cache[20]
            lv2130: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2129, lv249, sinfo_args=(R.Object,))
            lv2131: R.Object = kv_cache[21]
            lv2132: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2131, lv251, sinfo_args=(R.Object,))
            lv2133: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2130, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2134: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2132, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2135 = R.call_tir(cls.reshape3, (lv2133,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2136 = R.call_tir(cls.reshape3, (lv2134,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2138 = R.call_tir(cls.transpose6, (lv2135,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2139 = R.call_tir(cls.transpose6, (lv2136,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv252 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv247, lv2138, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv253 = R.call_tir(cls.fused_softmax_cast1, (lv252,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2148 = R.call_tir(cls.matmul9, (lv253, lv2139), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv254 = R.call_tir(cls.fused_transpose7_reshape4, (lv2148,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv255: R.Tensor((4096, 512), dtype="uint32") = params[104]
            lv256: R.Tensor((4096, 128), dtype="float16") = params[105]
            lv20_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv255, lv256, lv254, lv19_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv609: R.Tensor((4096,), dtype="float16") = params[111]
            lv2154 = R.call_tir(cls.rms_norm1, (lv20_3, lv609), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv259: R.Tensor((22016, 512), dtype="uint32") = params[106]
            lv260: R.Tensor((22016, 128), dtype="float16") = params[107]
            lv21 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv259, lv260, lv2154), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv262 = R.call_tir(cls.fused_split_silu_multiply, (lv21,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv263: R.Tensor((4096, 1376), dtype="uint32") = params[108]
            lv264: R.Tensor((4096, 344), dtype="float16") = params[109]
            lv21_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv263, lv264, lv262, lv20_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
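            # ---- decoder layer 11 (weights params[112..119], norms params[120]/params[121], kv_cache[22]/kv_cache[23]) ----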
            lv616: R.Tensor((4096,), dtype="float16") = params[120]
            lv2165 = R.call_tir(cls.rms_norm1, (lv21_1, lv616), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv267: R.Tensor((12288, 512), dtype="uint32") = params[112]
            lv268: R.Tensor((12288, 128), dtype="float16") = params[113]
            lv22_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv267, lv268, lv2165), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv44_1 = R.call_tir(cls.split_rotary, (lv22_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv270: R.Tensor((1, 1, 4096), dtype="float16") = lv44_1[0]
            lv271 = R.call_tir(cls.fused_reshape2_transpose5, (lv270,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv272: R.Tensor((1, 1, 4096), dtype="float16") = lv44_1[1]
            lv273 = R.call_tir(cls.fused_reshape2_squeeze, (lv272,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv274: R.Tensor((1, 1, 4096), dtype="float16") = lv44_1[2]
            lv275 = R.call_tir(cls.fused_reshape2_squeeze, (lv274,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2179: R.Object = kv_cache[22]
            lv2180: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2179, lv273, sinfo_args=(R.Object,))
            lv2181: R.Object = kv_cache[23]
            lv2182: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2181, lv275, sinfo_args=(R.Object,))
            lv2183: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2180, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2184: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2182, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2185 = R.call_tir(cls.reshape3, (lv2183,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2186 = R.call_tir(cls.reshape3, (lv2184,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2188 = R.call_tir(cls.transpose6, (lv2185,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2189 = R.call_tir(cls.transpose6, (lv2186,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv276 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv271, lv2188, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv277 = R.call_tir(cls.fused_softmax_cast1, (lv276,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2198 = R.call_tir(cls.matmul9, (lv277, lv2189), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv278 = R.call_tir(cls.fused_transpose7_reshape4, (lv2198,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv279: R.Tensor((4096, 512), dtype="uint32") = params[114]
            lv280: R.Tensor((4096, 128), dtype="float16") = params[115]
            lv22_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv279, lv280, lv278, lv21_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv623: R.Tensor((4096,), dtype="float16") = params[121]
            lv2204 = R.call_tir(cls.rms_norm1, (lv22_2, lv623), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv283: R.Tensor((22016, 512), dtype="uint32") = params[116]
            lv284: R.Tensor((22016, 128), dtype="float16") = params[117]
            lv23_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv283, lv284, lv2204), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv286 = R.call_tir(cls.fused_split_silu_multiply, (lv23_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv287: R.Tensor((4096, 1376), dtype="uint32") = params[118]
            lv288: R.Tensor((4096, 344), dtype="float16") = params[119]
            lv23_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv287, lv288, lv286, lv22_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
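            # Decoder block, kv_cache slots 24/25 (same unrolled pattern as above).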
            lv630: R.Tensor((4096,), dtype="float16") = params[130]
            lv2215 = R.call_tir(cls.rms_norm1, (lv23_2, lv630), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv291: R.Tensor((12288, 512), dtype="uint32") = params[122]
            lv292: R.Tensor((12288, 128), dtype="float16") = params[123]
            lv24_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv291, lv292, lv2215), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv48_1 = R.call_tir(cls.split_rotary, (lv24_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv294: R.Tensor((1, 1, 4096), dtype="float16") = lv48_1[0]
            lv295 = R.call_tir(cls.fused_reshape2_transpose5, (lv294,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv296: R.Tensor((1, 1, 4096), dtype="float16") = lv48_1[1]
            lv297 = R.call_tir(cls.fused_reshape2_squeeze, (lv296,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv298: R.Tensor((1, 1, 4096), dtype="float16") = lv48_1[2]
            lv299 = R.call_tir(cls.fused_reshape2_squeeze, (lv298,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2229: R.Object = kv_cache[24]
            lv2230: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2229, lv297, sinfo_args=(R.Object,))
            lv2231: R.Object = kv_cache[25]
            lv2232: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2231, lv299, sinfo_args=(R.Object,))
            lv2233: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2230, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2234: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2232, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2235 = R.call_tir(cls.reshape3, (lv2233,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2236 = R.call_tir(cls.reshape3, (lv2234,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2238 = R.call_tir(cls.transpose6, (lv2235,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2239 = R.call_tir(cls.transpose6, (lv2236,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv300 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv295, lv2238, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv301 = R.call_tir(cls.fused_softmax_cast1, (lv300,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2248 = R.call_tir(cls.matmul9, (lv301, lv2239), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv302 = R.call_tir(cls.fused_transpose7_reshape4, (lv2248,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv303: R.Tensor((4096, 512), dtype="uint32") = params[124]
            lv304: R.Tensor((4096, 128), dtype="float16") = params[125]
            lv24_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv303, lv304, lv302, lv23_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv637: R.Tensor((4096,), dtype="float16") = params[131]
            lv2254 = R.call_tir(cls.rms_norm1, (lv24_3, lv637), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv307: R.Tensor((22016, 512), dtype="uint32") = params[126]
            lv308: R.Tensor((22016, 128), dtype="float16") = params[127]
            lv25 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv307, lv308, lv2254), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv310 = R.call_tir(cls.fused_split_silu_multiply, (lv25,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv311: R.Tensor((4096, 1376), dtype="uint32") = params[128]
            lv312: R.Tensor((4096, 344), dtype="float16") = params[129]
            lv25_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv311, lv312, lv310, lv24_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
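            # Decoder block, kv_cache slots 26/27 (same unrolled pattern as above).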
            lv644: R.Tensor((4096,), dtype="float16") = params[140]
            lv2265 = R.call_tir(cls.rms_norm1, (lv25_1, lv644), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv315: R.Tensor((12288, 512), dtype="uint32") = params[132]
            lv316: R.Tensor((12288, 128), dtype="float16") = params[133]
            lv26 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv315, lv316, lv2265), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv52_1 = R.call_tir(cls.split_rotary, (lv26, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv318: R.Tensor((1, 1, 4096), dtype="float16") = lv52_1[0]
            lv319 = R.call_tir(cls.fused_reshape2_transpose5, (lv318,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv320: R.Tensor((1, 1, 4096), dtype="float16") = lv52_1[1]
            lv321 = R.call_tir(cls.fused_reshape2_squeeze, (lv320,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv322: R.Tensor((1, 1, 4096), dtype="float16") = lv52_1[2]
            lv323 = R.call_tir(cls.fused_reshape2_squeeze, (lv322,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2279: R.Object = kv_cache[26]
            lv2280: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2279, lv321, sinfo_args=(R.Object,))
            lv2281: R.Object = kv_cache[27]
            lv2282: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2281, lv323, sinfo_args=(R.Object,))
            lv2283: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2280, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2284: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2282, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2285 = R.call_tir(cls.reshape3, (lv2283,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2286 = R.call_tir(cls.reshape3, (lv2284,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2288 = R.call_tir(cls.transpose6, (lv2285,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2289 = R.call_tir(cls.transpose6, (lv2286,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv324 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv319, lv2288, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv325 = R.call_tir(cls.fused_softmax_cast1, (lv324,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2298 = R.call_tir(cls.matmul9, (lv325, lv2289), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv326 = R.call_tir(cls.fused_transpose7_reshape4, (lv2298,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv327: R.Tensor((4096, 512), dtype="uint32") = params[134]
            lv328: R.Tensor((4096, 128), dtype="float16") = params[135]
            lv26_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv327, lv328, lv326, lv25_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv651: R.Tensor((4096,), dtype="float16") = params[141]
            lv2304 = R.call_tir(cls.rms_norm1, (lv26_1, lv651), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv331: R.Tensor((22016, 512), dtype="uint32") = params[136]
            lv332: R.Tensor((22016, 128), dtype="float16") = params[137]
            lv27_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv331, lv332, lv2304), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv334 = R.call_tir(cls.fused_split_silu_multiply, (lv27_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv335: R.Tensor((4096, 1376), dtype="uint32") = params[138]
            lv336: R.Tensor((4096, 344), dtype="float16") = params[139]
            lv27_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv335, lv336, lv334, lv26_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
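            # Decoder block, kv_cache slots 28/29 (same unrolled pattern as above).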
            lv658: R.Tensor((4096,), dtype="float16") = params[150]
            lv2315 = R.call_tir(cls.rms_norm1, (lv27_2, lv658), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv339: R.Tensor((12288, 512), dtype="uint32") = params[142]
            lv340: R.Tensor((12288, 128), dtype="float16") = params[143]
            lv28_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv339, lv340, lv2315), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv56_1 = R.call_tir(cls.split_rotary, (lv28_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv342: R.Tensor((1, 1, 4096), dtype="float16") = lv56_1[0]
            lv343 = R.call_tir(cls.fused_reshape2_transpose5, (lv342,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv344: R.Tensor((1, 1, 4096), dtype="float16") = lv56_1[1]
            lv345 = R.call_tir(cls.fused_reshape2_squeeze, (lv344,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv346: R.Tensor((1, 1, 4096), dtype="float16") = lv56_1[2]
            lv347 = R.call_tir(cls.fused_reshape2_squeeze, (lv346,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2329: R.Object = kv_cache[28]
            lv2330: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2329, lv345, sinfo_args=(R.Object,))
            lv2331: R.Object = kv_cache[29]
            lv2332: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2331, lv347, sinfo_args=(R.Object,))
            lv2333: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2330, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2334: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2332, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2335 = R.call_tir(cls.reshape3, (lv2333,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2336 = R.call_tir(cls.reshape3, (lv2334,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2338 = R.call_tir(cls.transpose6, (lv2335,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2339 = R.call_tir(cls.transpose6, (lv2336,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv348 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv343, lv2338, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv349 = R.call_tir(cls.fused_softmax_cast1, (lv348,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2348 = R.call_tir(cls.matmul9, (lv349, lv2339), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv350 = R.call_tir(cls.fused_transpose7_reshape4, (lv2348,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv351: R.Tensor((4096, 512), dtype="uint32") = params[144]
            lv352: R.Tensor((4096, 128), dtype="float16") = params[145]
            lv28_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv351, lv352, lv350, lv27_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv665: R.Tensor((4096,), dtype="float16") = params[151]
            lv2354 = R.call_tir(cls.rms_norm1, (lv28_3, lv665), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv355: R.Tensor((22016, 512), dtype="uint32") = params[146]
            lv356: R.Tensor((22016, 128), dtype="float16") = params[147]
            lv29 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv355, lv356, lv2354), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv358 = R.call_tir(cls.fused_split_silu_multiply, (lv29,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv359: R.Tensor((4096, 1376), dtype="uint32") = params[148]
            lv360: R.Tensor((4096, 344), dtype="float16") = params[149]
            lv29_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv359, lv360, lv358, lv28_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
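            # Decoder block, kv_cache slots 30/31 (same unrolled pattern as above).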
            lv672: R.Tensor((4096,), dtype="float16") = params[160]
            lv2365 = R.call_tir(cls.rms_norm1, (lv29_1, lv672), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv363: R.Tensor((12288, 512), dtype="uint32") = params[152]
            lv364: R.Tensor((12288, 128), dtype="float16") = params[153]
            lv30_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv363, lv364, lv2365), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv60_1 = R.call_tir(cls.split_rotary, (lv30_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv366: R.Tensor((1, 1, 4096), dtype="float16") = lv60_1[0]
            lv367 = R.call_tir(cls.fused_reshape2_transpose5, (lv366,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv368: R.Tensor((1, 1, 4096), dtype="float16") = lv60_1[1]
            lv369 = R.call_tir(cls.fused_reshape2_squeeze, (lv368,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv370: R.Tensor((1, 1, 4096), dtype="float16") = lv60_1[2]
            lv371 = R.call_tir(cls.fused_reshape2_squeeze, (lv370,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2379: R.Object = kv_cache[30]
            lv2380: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2379, lv369, sinfo_args=(R.Object,))
            lv2381: R.Object = kv_cache[31]
            lv2382: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2381, lv371, sinfo_args=(R.Object,))
            lv2383: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2380, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2384: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2382, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2385 = R.call_tir(cls.reshape3, (lv2383,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2386 = R.call_tir(cls.reshape3, (lv2384,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2388 = R.call_tir(cls.transpose6, (lv2385,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2389 = R.call_tir(cls.transpose6, (lv2386,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv372 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv367, lv2388, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv373 = R.call_tir(cls.fused_softmax_cast1, (lv372,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2398 = R.call_tir(cls.matmul9, (lv373, lv2389), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv374 = R.call_tir(cls.fused_transpose7_reshape4, (lv2398,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv375: R.Tensor((4096, 512), dtype="uint32") = params[154]
            lv376: R.Tensor((4096, 128), dtype="float16") = params[155]
            lv30_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv375, lv376, lv374, lv29_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv679: R.Tensor((4096,), dtype="float16") = params[161]
            lv2404 = R.call_tir(cls.rms_norm1, (lv30_2, lv679), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv379: R.Tensor((22016, 512), dtype="uint32") = params[156]
            lv380: R.Tensor((22016, 128), dtype="float16") = params[157]
            lv31_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv379, lv380, lv2404), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv382 = R.call_tir(cls.fused_split_silu_multiply, (lv31_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv383: R.Tensor((4096, 1376), dtype="uint32") = params[158]
            lv384: R.Tensor((4096, 344), dtype="float16") = params[159]
            lv31_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv383, lv384, lv382, lv30_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
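            # Decoder block, kv_cache slots 32/33 (same unrolled pattern as above).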
            lv686: R.Tensor((4096,), dtype="float16") = params[170]
            lv2415 = R.call_tir(cls.rms_norm1, (lv31_2, lv686), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv387: R.Tensor((12288, 512), dtype="uint32") = params[162]
            lv388: R.Tensor((12288, 128), dtype="float16") = params[163]
            lv32_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv387, lv388, lv2415), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv64_1 = R.call_tir(cls.split_rotary, (lv32_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv390: R.Tensor((1, 1, 4096), dtype="float16") = lv64_1[0]
            lv391 = R.call_tir(cls.fused_reshape2_transpose5, (lv390,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv392: R.Tensor((1, 1, 4096), dtype="float16") = lv64_1[1]
            lv393 = R.call_tir(cls.fused_reshape2_squeeze, (lv392,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv394: R.Tensor((1, 1, 4096), dtype="float16") = lv64_1[2]
            lv395 = R.call_tir(cls.fused_reshape2_squeeze, (lv394,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2429: R.Object = kv_cache[32]
            lv2430: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2429, lv393, sinfo_args=(R.Object,))
            lv2431: R.Object = kv_cache[33]
            lv2432: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2431, lv395, sinfo_args=(R.Object,))
            lv2433: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2430, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2434: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2432, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2435 = R.call_tir(cls.reshape3, (lv2433,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2436 = R.call_tir(cls.reshape3, (lv2434,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2438 = R.call_tir(cls.transpose6, (lv2435,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2439 = R.call_tir(cls.transpose6, (lv2436,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv396 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv391, lv2438, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv397 = R.call_tir(cls.fused_softmax_cast1, (lv396,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2448 = R.call_tir(cls.matmul9, (lv397, lv2439), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv398 = R.call_tir(cls.fused_transpose7_reshape4, (lv2448,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv399: R.Tensor((4096, 512), dtype="uint32") = params[164]
            lv400: R.Tensor((4096, 128), dtype="float16") = params[165]
            lv32_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv399, lv400, lv398, lv31_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv693: R.Tensor((4096,), dtype="float16") = params[171]
            lv2454 = R.call_tir(cls.rms_norm1, (lv32_3, lv693), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv403: R.Tensor((22016, 512), dtype="uint32") = params[166]
            lv404: R.Tensor((22016, 128), dtype="float16") = params[167]
            lv33_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv403, lv404, lv2454), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv406 = R.call_tir(cls.fused_split_silu_multiply, (lv33_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv407: R.Tensor((4096, 1376), dtype="uint32") = params[168]
            lv408: R.Tensor((4096, 344), dtype="float16") = params[169]
            lv33_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv407, lv408, lv406, lv32_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
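            # Decoder block, kv_cache slots 34/35 (same unrolled pattern as above).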
            lv700: R.Tensor((4096,), dtype="float16") = params[180]
            lv2465 = R.call_tir(cls.rms_norm1, (lv33_2, lv700), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv411: R.Tensor((12288, 512), dtype="uint32") = params[172]
            lv412: R.Tensor((12288, 128), dtype="float16") = params[173]
            lv34_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv411, lv412, lv2465), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv68_1 = R.call_tir(cls.split_rotary, (lv34_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv414: R.Tensor((1, 1, 4096), dtype="float16") = lv68_1[0]
            lv415 = R.call_tir(cls.fused_reshape2_transpose5, (lv414,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv416: R.Tensor((1, 1, 4096), dtype="float16") = lv68_1[1]
            lv417 = R.call_tir(cls.fused_reshape2_squeeze, (lv416,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv418: R.Tensor((1, 1, 4096), dtype="float16") = lv68_1[2]
            lv419 = R.call_tir(cls.fused_reshape2_squeeze, (lv418,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2479: R.Object = kv_cache[34]
            lv2480: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2479, lv417, sinfo_args=(R.Object,))
            lv2481: R.Object = kv_cache[35]
            lv2482: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2481, lv419, sinfo_args=(R.Object,))
            lv2483: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2480, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2484: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2482, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2485 = R.call_tir(cls.reshape3, (lv2483,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2486 = R.call_tir(cls.reshape3, (lv2484,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2488 = R.call_tir(cls.transpose6, (lv2485,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2489 = R.call_tir(cls.transpose6, (lv2486,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv420 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv415, lv2488, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv421 = R.call_tir(cls.fused_softmax_cast1, (lv420,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2498 = R.call_tir(cls.matmul9, (lv421, lv2489), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv422 = R.call_tir(cls.fused_transpose7_reshape4, (lv2498,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv423: R.Tensor((4096, 512), dtype="uint32") = params[174]
            lv424: R.Tensor((4096, 128), dtype="float16") = params[175]
            lv34_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv423, lv424, lv422, lv33_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv707: R.Tensor((4096,), dtype="float16") = params[181]
            lv2504 = R.call_tir(cls.rms_norm1, (lv34_2, lv707), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv427: R.Tensor((22016, 512), dtype="uint32") = params[176]
            lv428: R.Tensor((22016, 128), dtype="float16") = params[177]
            lv35_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv427, lv428, lv2504), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv430 = R.call_tir(cls.fused_split_silu_multiply, (lv35_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv431: R.Tensor((4096, 1376), dtype="uint32") = params[178]
            lv432: R.Tensor((4096, 344), dtype="float16") = params[179]
            lv35_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv431, lv432, lv430, lv34_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
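            # Decoder block, kv_cache slots 36/37 (same unrolled pattern as above).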
            lv714: R.Tensor((4096,), dtype="float16") = params[190]
            lv2515 = R.call_tir(cls.rms_norm1, (lv35_2, lv714), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv435: R.Tensor((12288, 512), dtype="uint32") = params[182]
            lv436: R.Tensor((12288, 128), dtype="float16") = params[183]
            lv36_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv435, lv436, lv2515), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv72_1 = R.call_tir(cls.split_rotary, (lv36_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv438: R.Tensor((1, 1, 4096), dtype="float16") = lv72_1[0]
            lv439 = R.call_tir(cls.fused_reshape2_transpose5, (lv438,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv440: R.Tensor((1, 1, 4096), dtype="float16") = lv72_1[1]
            lv441 = R.call_tir(cls.fused_reshape2_squeeze, (lv440,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv442: R.Tensor((1, 1, 4096), dtype="float16") = lv72_1[2]
            lv443 = R.call_tir(cls.fused_reshape2_squeeze, (lv442,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2529: R.Object = kv_cache[36]
            lv2530: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2529, lv441, sinfo_args=(R.Object,))
            lv2531: R.Object = kv_cache[37]
            lv2532: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2531, lv443, sinfo_args=(R.Object,))
            lv2533: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2530, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2534: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2532, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2535 = R.call_tir(cls.reshape3, (lv2533,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2536 = R.call_tir(cls.reshape3, (lv2534,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2538 = R.call_tir(cls.transpose6, (lv2535,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2539 = R.call_tir(cls.transpose6, (lv2536,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv444 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv439, lv2538, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv445 = R.call_tir(cls.fused_softmax_cast1, (lv444,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2548 = R.call_tir(cls.matmul9, (lv445, lv2539), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv446 = R.call_tir(cls.fused_transpose7_reshape4, (lv2548,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv447: R.Tensor((4096, 512), dtype="uint32") = params[184]
            lv448: R.Tensor((4096, 128), dtype="float16") = params[185]
            lv36_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv447, lv448, lv446, lv35_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv721: R.Tensor((4096,), dtype="float16") = params[191]
            lv2554 = R.call_tir(cls.rms_norm1, (lv36_3, lv721), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv451: R.Tensor((22016, 512), dtype="uint32") = params[186]
            lv452: R.Tensor((22016, 128), dtype="float16") = params[187]
            lv37_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv451, lv452, lv2554), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv454 = R.call_tir(cls.fused_split_silu_multiply, (lv37_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv455: R.Tensor((4096, 1376), dtype="uint32") = params[188]
            lv456: R.Tensor((4096, 344), dtype="float16") = params[189]
            lv37_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv455, lv456, lv454, lv36_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
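            # Decoder block, kv_cache slots 38/39 (same unrolled pattern; note the
            # printer renames clashing variables here, e.g. lv464_1/lv465_1 for the
            # split_rotary outputs versus the lv464/lv465 rotary tables).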
            lv728: R.Tensor((4096,), dtype="float16") = params[200]
            lv2565 = R.call_tir(cls.rms_norm1, (lv37_2, lv728), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv459: R.Tensor((12288, 512), dtype="uint32") = params[192]
            lv460_1: R.Tensor((12288, 128), dtype="float16") = params[193]
            lv38_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv459, lv460_1, lv2565), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv76_1 = R.call_tir(cls.split_rotary, (lv38_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv462: R.Tensor((1, 1, 4096), dtype="float16") = lv76_1[0]
            lv463 = R.call_tir(cls.fused_reshape2_transpose5, (lv462,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv464_1: R.Tensor((1, 1, 4096), dtype="float16") = lv76_1[1]
            lv465_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv464_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv466: R.Tensor((1, 1, 4096), dtype="float16") = lv76_1[2]
            lv467 = R.call_tir(cls.fused_reshape2_squeeze, (lv466,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2579: R.Object = kv_cache[38]
            lv2580: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2579, lv465_1, sinfo_args=(R.Object,))
            lv2581: R.Object = kv_cache[39]
            lv2582: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2581, lv467, sinfo_args=(R.Object,))
            lv2583: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2580, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2584: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2582, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2585 = R.call_tir(cls.reshape3, (lv2583,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2586 = R.call_tir(cls.reshape3, (lv2584,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2588 = R.call_tir(cls.transpose6, (lv2585,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2589 = R.call_tir(cls.transpose6, (lv2586,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv468 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv463, lv2588, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv469_1 = R.call_tir(cls.fused_softmax_cast1, (lv468,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2598 = R.call_tir(cls.matmul9, (lv469_1, lv2589), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv470 = R.call_tir(cls.fused_transpose7_reshape4, (lv2598,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv471: R.Tensor((4096, 512), dtype="uint32") = params[194]
            lv472: R.Tensor((4096, 128), dtype="float16") = params[195]
            lv38_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv471, lv472, lv470, lv37_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv735: R.Tensor((4096,), dtype="float16") = params[201]
            lv2604 = R.call_tir(cls.rms_norm1, (lv38_2, lv735), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv475: R.Tensor((22016, 512), dtype="uint32") = params[196]
            lv476_1: R.Tensor((22016, 128), dtype="float16") = params[197]
            lv39_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv475, lv476_1, lv2604), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv478 = R.call_tir(cls.fused_split_silu_multiply, (lv39_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv479: R.Tensor((4096, 1376), dtype="uint32") = params[198]
            lv480: R.Tensor((4096, 344), dtype="float16") = params[199]
            lv39_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv479, lv480, lv478, lv38_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
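            # Decoder block, kv_cache slots 40/41 (same unrolled pattern as above).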
            lv742: R.Tensor((4096,), dtype="float16") = params[210]
            lv2615 = R.call_tir(cls.rms_norm1, (lv39_2, lv742), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv483_1: R.Tensor((12288, 512), dtype="uint32") = params[202]
            lv484: R.Tensor((12288, 128), dtype="float16") = params[203]
            lv40_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv483_1, lv484, lv2615), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv80_1 = R.call_tir(cls.split_rotary, (lv40_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv486: R.Tensor((1, 1, 4096), dtype="float16") = lv80_1[0]
            lv487 = R.call_tir(cls.fused_reshape2_transpose5, (lv486,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv488: R.Tensor((1, 1, 4096), dtype="float16") = lv80_1[1]
            lv489 = R.call_tir(cls.fused_reshape2_squeeze, (lv488,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv490_1: R.Tensor((1, 1, 4096), dtype="float16") = lv80_1[2]
            lv491 = R.call_tir(cls.fused_reshape2_squeeze, (lv490_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2629: R.Object = kv_cache[40]
            lv2630: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2629, lv489, sinfo_args=(R.Object,))
            lv2631: R.Object = kv_cache[41]
            lv2632: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2631, lv491, sinfo_args=(R.Object,))
            lv2633: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2630, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2634: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2632, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2635 = R.call_tir(cls.reshape3, (lv2633,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2636 = R.call_tir(cls.reshape3, (lv2634,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2638 = R.call_tir(cls.transpose6, (lv2635,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2639 = R.call_tir(cls.transpose6, (lv2636,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv492 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv487, lv2638, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv493 = R.call_tir(cls.fused_softmax_cast1, (lv492,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2648 = R.call_tir(cls.matmul9, (lv493, lv2639), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv494 = R.call_tir(cls.fused_transpose7_reshape4, (lv2648,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv495: R.Tensor((4096, 512), dtype="uint32") = params[204]
            lv496: R.Tensor((4096, 128), dtype="float16") = params[205]
            lv40_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv495, lv496, lv494, lv39_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv749: R.Tensor((4096,), dtype="float16") = params[211]
            lv2654 = R.call_tir(cls.rms_norm1, (lv40_3, lv749), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv499: R.Tensor((22016, 512), dtype="uint32") = params[206]
            lv500: R.Tensor((22016, 128), dtype="float16") = params[207]
            lv41 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv499, lv500, lv2654), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv502 = R.call_tir(cls.fused_split_silu_multiply, (lv41,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv503: R.Tensor((4096, 1376), dtype="uint32") = params[208]
            lv504_1: R.Tensor((4096, 344), dtype="float16") = params[209]
            lv41_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv503, lv504_1, lv502, lv40_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
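            # Decoder block, kv_cache slots 42/43 (same unrolled pattern as above).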
            lv756: R.Tensor((4096,), dtype="float16") = params[220]
            lv2665 = R.call_tir(cls.rms_norm1, (lv41_1, lv756), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv507: R.Tensor((12288, 512), dtype="uint32") = params[212]
            lv508: R.Tensor((12288, 128), dtype="float16") = params[213]
            lv42 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv507, lv508, lv2665), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv84_1 = R.call_tir(cls.split_rotary, (lv42, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv510: R.Tensor((1, 1, 4096), dtype="float16") = lv84_1[0]
            lv511_1 = R.call_tir(cls.fused_reshape2_transpose5, (lv510,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv512: R.Tensor((1, 1, 4096), dtype="float16") = lv84_1[1]
            lv513 = R.call_tir(cls.fused_reshape2_squeeze, (lv512,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv514: R.Tensor((1, 1, 4096), dtype="float16") = lv84_1[2]
            lv515 = R.call_tir(cls.fused_reshape2_squeeze, (lv514,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2679: R.Object = kv_cache[42]
            lv2680: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2679, lv513, sinfo_args=(R.Object,))
            lv2681: R.Object = kv_cache[43]
            lv2682: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2681, lv515, sinfo_args=(R.Object,))
            lv2683: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2680, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2684: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2682, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2685 = R.call_tir(cls.reshape3, (lv2683,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2686 = R.call_tir(cls.reshape3, (lv2684,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2688 = R.call_tir(cls.transpose6, (lv2685,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2689 = R.call_tir(cls.transpose6, (lv2686,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv516 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv511_1, lv2688, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv517 = R.call_tir(cls.fused_softmax_cast1, (lv516,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2698 = R.call_tir(cls.matmul9, (lv517, lv2689), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv518_1 = R.call_tir(cls.fused_transpose7_reshape4, (lv2698,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv519: R.Tensor((4096, 512), dtype="uint32") = params[214]
            lv520: R.Tensor((4096, 128), dtype="float16") = params[215]
            lv42_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv519, lv520, lv518_1, lv41_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv763: R.Tensor((4096,), dtype="float16") = params[221]
            lv2704 = R.call_tir(cls.rms_norm1, (lv42_1, lv763), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv523: R.Tensor((22016, 512), dtype="uint32") = params[216]
            lv524: R.Tensor((22016, 128), dtype="float16") = params[217]
            lv43_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv523, lv524, lv2704), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv526 = R.call_tir(cls.fused_split_silu_multiply, (lv43_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv527: R.Tensor((4096, 1376), dtype="uint32") = params[218]
            lv528: R.Tensor((4096, 344), dtype="float16") = params[219]
            lv43_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv527, lv528, lv526, lv42_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
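            # Decoder block, kv_cache slots 44/45 (same unrolled pattern as above).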
            lv770: R.Tensor((4096,), dtype="float16") = params[230]
            lv2715 = R.call_tir(cls.rms_norm1, (lv43_2, lv770), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv531: R.Tensor((12288, 512), dtype="uint32") = params[222]
            lv532_1: R.Tensor((12288, 128), dtype="float16") = params[223]
            lv44_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv531, lv532_1, lv2715), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv88_1 = R.call_tir(cls.split_rotary, (lv44_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv534: R.Tensor((1, 1, 4096), dtype="float16") = lv88_1[0]
            lv535 = R.call_tir(cls.fused_reshape2_transpose5, (lv534,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv536: R.Tensor((1, 1, 4096), dtype="float16") = lv88_1[1]
            lv537 = R.call_tir(cls.fused_reshape2_squeeze, (lv536,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv538: R.Tensor((1, 1, 4096), dtype="float16") = lv88_1[2]
            lv539_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv538,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2729: R.Object = kv_cache[44]
            lv2730: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2729, lv537, sinfo_args=(R.Object,))
            lv2731: R.Object = kv_cache[45]
            lv2732: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2731, lv539_1, sinfo_args=(R.Object,))
            lv2733: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2730, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2734: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2732, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2735 = R.call_tir(cls.reshape3, (lv2733,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2736 = R.call_tir(cls.reshape3, (lv2734,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2738 = R.call_tir(cls.transpose6, (lv2735,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2739 = R.call_tir(cls.transpose6, (lv2736,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv540 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv535, lv2738, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv541 = R.call_tir(cls.fused_softmax_cast1, (lv540,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2748 = R.call_tir(cls.matmul9, (lv541, lv2739), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv542 = R.call_tir(cls.fused_transpose7_reshape4, (lv2748,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv543: R.Tensor((4096, 512), dtype="uint32") = params[224]
            lv544: R.Tensor((4096, 128), dtype="float16") = params[225]
            lv44_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv543, lv544, lv542, lv43_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv777: R.Tensor((4096,), dtype="float16") = params[231]
            lv2754 = R.call_tir(cls.rms_norm1, (lv44_3, lv777), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv547: R.Tensor((22016, 512), dtype="uint32") = params[226]
            lv548: R.Tensor((22016, 128), dtype="float16") = params[227]
            lv45 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv547, lv548, lv2754), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv550 = R.call_tir(cls.fused_split_silu_multiply, (lv45,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv551: R.Tensor((4096, 1376), dtype="uint32") = params[228]
            lv552: R.Tensor((4096, 344), dtype="float16") = params[229]
            lv45_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv551, lv552, lv550, lv44_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
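            # Decoder block, kv_cache slots 46/47 (same unrolled pattern as above).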
            lv784: R.Tensor((4096,), dtype="float16") = params[240]
            lv2765 = R.call_tir(cls.rms_norm1, (lv45_1, lv784), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv555: R.Tensor((12288, 512), dtype="uint32") = params[232]
            lv556: R.Tensor((12288, 128), dtype="float16") = params[233]
            lv46_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv555, lv556, lv2765), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv92_1 = R.call_tir(cls.split_rotary, (lv46_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv558: R.Tensor((1, 1, 4096), dtype="float16") = lv92_1[0]
            lv559 = R.call_tir(cls.fused_reshape2_transpose5, (lv558,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv560_1: R.Tensor((1, 1, 4096), dtype="float16") = lv92_1[1]
            lv561 = R.call_tir(cls.fused_reshape2_squeeze, (lv560_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv562: R.Tensor((1, 1, 4096), dtype="float16") = lv92_1[2]
            lv563 = R.call_tir(cls.fused_reshape2_squeeze, (lv562,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2779: R.Object = kv_cache[46]
            lv2780: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2779, lv561, sinfo_args=(R.Object,))
            lv2781: R.Object = kv_cache[47]
            lv2782: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2781, lv563, sinfo_args=(R.Object,))
            lv2783: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2780, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2784: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2782, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2785 = R.call_tir(cls.reshape3, (lv2783,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2786 = R.call_tir(cls.reshape3, (lv2784,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2788 = R.call_tir(cls.transpose6, (lv2785,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2789 = R.call_tir(cls.transpose6, (lv2786,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv564 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv559, lv2788, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv565 = R.call_tir(cls.fused_softmax_cast1, (lv564,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2798 = R.call_tir(cls.matmul9, (lv565, lv2789), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv566 = R.call_tir(cls.fused_transpose7_reshape4, (lv2798,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv567_1: R.Tensor((4096, 512), dtype="uint32") = params[234]
            lv568: R.Tensor((4096, 128), dtype="float16") = params[235]
            lv46_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv567_1, lv568, lv566, lv45_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv791: R.Tensor((4096,), dtype="float16") = params[241]
            lv2804 = R.call_tir(cls.rms_norm1, (lv46_2, lv791), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv571: R.Tensor((22016, 512), dtype="uint32") = params[236]
            lv572: R.Tensor((22016, 128), dtype="float16") = params[237]
            lv47_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv571, lv572, lv2804), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv574_1 = R.call_tir(cls.fused_split_silu_multiply, (lv47_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv575: R.Tensor((4096, 1376), dtype="uint32") = params[238]
            lv576: R.Tensor((4096, 344), dtype="float16") = params[239]
            lv47_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv575, lv576, lv574_1, lv46_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
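            # Decoder block, kv_cache slots 48/49 (same unrolled pattern as above).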
            lv798: R.Tensor((4096,), dtype="float16") = params[250]
            lv2815 = R.call_tir(cls.rms_norm1, (lv47_2, lv798), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv579: R.Tensor((12288, 512), dtype="uint32") = params[242]
            lv580: R.Tensor((12288, 128), dtype="float16") = params[243]
            lv48_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv579, lv580, lv2815), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv96_1 = R.call_tir(cls.split_rotary, (lv48_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv582: R.Tensor((1, 1, 4096), dtype="float16") = lv96_1[0]
            lv583 = R.call_tir(cls.fused_reshape2_transpose5, (lv582,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv584: R.Tensor((1, 1, 4096), dtype="float16") = lv96_1[1]
            lv585 = R.call_tir(cls.fused_reshape2_squeeze, (lv584,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv586: R.Tensor((1, 1, 4096), dtype="float16") = lv96_1[2]
            lv587 = R.call_tir(cls.fused_reshape2_squeeze, (lv586,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2829: R.Object = kv_cache[48]
            lv2830: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2829, lv585, sinfo_args=(R.Object,))
            lv2831: R.Object = kv_cache[49]
            lv2832: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2831, lv587, sinfo_args=(R.Object,))
            lv2833: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2830, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2834: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2832, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2835 = R.call_tir(cls.reshape3, (lv2833,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2836 = R.call_tir(cls.reshape3, (lv2834,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2838 = R.call_tir(cls.transpose6, (lv2835,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2839 = R.call_tir(cls.transpose6, (lv2836,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv588_1 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv583, lv2838, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv589 = R.call_tir(cls.fused_softmax_cast1, (lv588_1,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2848 = R.call_tir(cls.matmul9, (lv589, lv2839), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv590 = R.call_tir(cls.fused_transpose7_reshape4, (lv2848,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv591: R.Tensor((4096, 512), dtype="uint32") = params[244]
            lv592: R.Tensor((4096, 128), dtype="float16") = params[245]
            lv48_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv591, lv592, lv590, lv47_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv805: R.Tensor((4096,), dtype="float16") = params[251]
            lv2854 = R.call_tir(cls.rms_norm1, (lv48_3, lv805), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv595_1: R.Tensor((22016, 512), dtype="uint32") = params[246]
            lv596: R.Tensor((22016, 128), dtype="float16") = params[247]
            lv49 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv595_1, lv596, lv2854), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv598 = R.call_tir(cls.fused_split_silu_multiply, (lv49,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv599: R.Tensor((4096, 1376), dtype="uint32") = params[248]
            lv600: R.Tensor((4096, 344), dtype="float16") = params[249]
            lv49_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv599, lv600, lv598, lv48_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
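            # Decoder block, kv_cache slots 50/51 (same unrolled pattern as above).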
            lv812: R.Tensor((4096,), dtype="float16") = params[260]
            lv2865 = R.call_tir(cls.rms_norm1, (lv49_1, lv812), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv603: R.Tensor((12288, 512), dtype="uint32") = params[252]
            lv604: R.Tensor((12288, 128), dtype="float16") = params[253]
            lv50 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv603, lv604, lv2865), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv100_1 = R.call_tir(cls.split_rotary, (lv50, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv606: R.Tensor((1, 1, 4096), dtype="float16") = lv100_1[0]
            lv607 = R.call_tir(cls.fused_reshape2_transpose5, (lv606,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv608: R.Tensor((1, 1, 4096), dtype="float16") = lv100_1[1]
            lv609_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv608,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv610: R.Tensor((1, 1, 4096), dtype="float16") = lv100_1[2]
            lv611 = R.call_tir(cls.fused_reshape2_squeeze, (lv610,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2879: R.Object = kv_cache[50]
            lv2880: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2879, lv609_1, sinfo_args=(R.Object,))
            lv2881: R.Object = kv_cache[51]
            lv2882: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2881, lv611, sinfo_args=(R.Object,))
            lv2883: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2880, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2884: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2882, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2885 = R.call_tir(cls.reshape3, (lv2883,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2886 = R.call_tir(cls.reshape3, (lv2884,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2888 = R.call_tir(cls.transpose6, (lv2885,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2889 = R.call_tir(cls.transpose6, (lv2886,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv612 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv607, lv2888, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv613 = R.call_tir(cls.fused_softmax_cast1, (lv612,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2898 = R.call_tir(cls.matmul9, (lv613, lv2889), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv614 = R.call_tir(cls.fused_transpose7_reshape4, (lv2898,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv615: R.Tensor((4096, 512), dtype="uint32") = params[254]
            lv616_1: R.Tensor((4096, 128), dtype="float16") = params[255]
            lv50_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv615, lv616_1, lv614, lv49_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv819: R.Tensor((4096,), dtype="float16") = params[261]
            lv2904 = R.call_tir(cls.rms_norm1, (lv50_1, lv819), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv619: R.Tensor((22016, 512), dtype="uint32") = params[256]
            lv620: R.Tensor((22016, 128), dtype="float16") = params[257]
            lv51_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv619, lv620, lv2904), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv622 = R.call_tir(cls.fused_split_silu_multiply, (lv51_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv623_1: R.Tensor((4096, 1376), dtype="uint32") = params[258]
            lv624: R.Tensor((4096, 344), dtype="float16") = params[259]
            lv51_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv623_1, lv624, lv622, lv50_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
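            # Next decoder layer: same pattern, kv_cache slots 52/53.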
            lv826: R.Tensor((4096,), dtype="float16") = params[270]
            lv2915 = R.call_tir(cls.rms_norm1, (lv51_2, lv826), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv627: R.Tensor((12288, 512), dtype="uint32") = params[262]
            lv628: R.Tensor((12288, 128), dtype="float16") = params[263]
            lv52_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv627, lv628, lv2915), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv104_1 = R.call_tir(cls.split_rotary, (lv52_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv630_1: R.Tensor((1, 1, 4096), dtype="float16") = lv104_1[0]
            lv631 = R.call_tir(cls.fused_reshape2_transpose5, (lv630_1,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv632: R.Tensor((1, 1, 4096), dtype="float16") = lv104_1[1]
            lv633 = R.call_tir(cls.fused_reshape2_squeeze, (lv632,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv634: R.Tensor((1, 1, 4096), dtype="float16") = lv104_1[2]
            lv635 = R.call_tir(cls.fused_reshape2_squeeze, (lv634,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2929: R.Object = kv_cache[52]
            lv2930: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2929, lv633, sinfo_args=(R.Object,))
            lv2931: R.Object = kv_cache[53]
            lv2932: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2931, lv635, sinfo_args=(R.Object,))
            lv2933: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2930, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2934: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2932, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2935 = R.call_tir(cls.reshape3, (lv2933,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2936 = R.call_tir(cls.reshape3, (lv2934,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2938 = R.call_tir(cls.transpose6, (lv2935,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2939 = R.call_tir(cls.transpose6, (lv2936,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv636 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv631, lv2938, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv637_1 = R.call_tir(cls.fused_softmax_cast1, (lv636,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2948 = R.call_tir(cls.matmul9, (lv637_1, lv2939), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv638 = R.call_tir(cls.fused_transpose7_reshape4, (lv2948,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv639: R.Tensor((4096, 512), dtype="uint32") = params[264]
            lv640: R.Tensor((4096, 128), dtype="float16") = params[265]
            lv52_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv639, lv640, lv638, lv51_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv833: R.Tensor((4096,), dtype="float16") = params[271]
            lv2954 = R.call_tir(cls.rms_norm1, (lv52_3, lv833), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv643: R.Tensor((22016, 512), dtype="uint32") = params[266]
            lv644_1: R.Tensor((22016, 128), dtype="float16") = params[267]
            lv53 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv643, lv644_1, lv2954), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv646 = R.call_tir(cls.fused_split_silu_multiply, (lv53,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv647: R.Tensor((4096, 1376), dtype="uint32") = params[268]
            lv648: R.Tensor((4096, 344), dtype="float16") = params[269]
            lv53_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv647, lv648, lv646, lv52_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
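            # Next decoder layer: same pattern, kv_cache slots 54/55.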
            lv840: R.Tensor((4096,), dtype="float16") = params[280]
            lv2965 = R.call_tir(cls.rms_norm1, (lv53_1, lv840), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv651_1: R.Tensor((12288, 512), dtype="uint32") = params[272]
            lv652: R.Tensor((12288, 128), dtype="float16") = params[273]
            lv54_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv651_1, lv652, lv2965), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv108_1 = R.call_tir(cls.split_rotary, (lv54_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv654: R.Tensor((1, 1, 4096), dtype="float16") = lv108_1[0]
            lv655 = R.call_tir(cls.fused_reshape2_transpose5, (lv654,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv656: R.Tensor((1, 1, 4096), dtype="float16") = lv108_1[1]
            lv657 = R.call_tir(cls.fused_reshape2_squeeze, (lv656,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv658_1: R.Tensor((1, 1, 4096), dtype="float16") = lv108_1[2]
            lv659 = R.call_tir(cls.fused_reshape2_squeeze, (lv658_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv2979: R.Object = kv_cache[54]
            lv2980: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2979, lv657, sinfo_args=(R.Object,))
            lv2981: R.Object = kv_cache[55]
            lv2982: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2981, lv659, sinfo_args=(R.Object,))
            lv2983: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2980, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2984: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2982, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv2985 = R.call_tir(cls.reshape3, (lv2983,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2986 = R.call_tir(cls.reshape3, (lv2984,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv2988 = R.call_tir(cls.transpose6, (lv2985,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2989 = R.call_tir(cls.transpose6, (lv2986,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv660 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv655, lv2988, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv661 = R.call_tir(cls.fused_softmax_cast1, (lv660,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2998 = R.call_tir(cls.matmul9, (lv661, lv2989), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv662 = R.call_tir(cls.fused_transpose7_reshape4, (lv2998,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv663: R.Tensor((4096, 512), dtype="uint32") = params[274]
            lv664: R.Tensor((4096, 128), dtype="float16") = params[275]
            lv54_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv663, lv664, lv662, lv53_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv847: R.Tensor((4096,), dtype="float16") = params[281]
            lv3004 = R.call_tir(cls.rms_norm1, (lv54_2, lv847), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv667: R.Tensor((22016, 512), dtype="uint32") = params[276]
            lv668: R.Tensor((22016, 128), dtype="float16") = params[277]
            lv55_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv667, lv668, lv3004), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv670 = R.call_tir(cls.fused_split_silu_multiply, (lv55_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv671: R.Tensor((4096, 1376), dtype="uint32") = params[278]
            lv672_1: R.Tensor((4096, 344), dtype="float16") = params[279]
            lv55_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv671, lv672_1, lv670, lv54_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
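            # Next decoder layer: same pattern, kv_cache slots 56/57.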
            lv854: R.Tensor((4096,), dtype="float16") = params[290]
            lv3015 = R.call_tir(cls.rms_norm1, (lv55_2, lv854), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv675: R.Tensor((12288, 512), dtype="uint32") = params[282]
            lv676: R.Tensor((12288, 128), dtype="float16") = params[283]
            lv56_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv675, lv676, lv3015), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv112_1 = R.call_tir(cls.split_rotary, (lv56_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv678: R.Tensor((1, 1, 4096), dtype="float16") = lv112_1[0]
            lv679_1 = R.call_tir(cls.fused_reshape2_transpose5, (lv678,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv680: R.Tensor((1, 1, 4096), dtype="float16") = lv112_1[1]
            lv681 = R.call_tir(cls.fused_reshape2_squeeze, (lv680,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv682: R.Tensor((1, 1, 4096), dtype="float16") = lv112_1[2]
            lv683 = R.call_tir(cls.fused_reshape2_squeeze, (lv682,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv3029: R.Object = kv_cache[56]
            lv3030: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3029, lv681, sinfo_args=(R.Object,))
            lv3031: R.Object = kv_cache[57]
            lv3032: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3031, lv683, sinfo_args=(R.Object,))
            lv3033: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3030, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv3034: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3032, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv3035 = R.call_tir(cls.reshape3, (lv3033,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv3036 = R.call_tir(cls.reshape3, (lv3034,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv3038 = R.call_tir(cls.transpose6, (lv3035,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv3039 = R.call_tir(cls.transpose6, (lv3036,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv684 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv679_1, lv3038, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv685 = R.call_tir(cls.fused_softmax_cast1, (lv684,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv3048 = R.call_tir(cls.matmul9, (lv685, lv3039), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv686_1 = R.call_tir(cls.fused_transpose7_reshape4, (lv3048,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv687: R.Tensor((4096, 512), dtype="uint32") = params[284]
            lv688: R.Tensor((4096, 128), dtype="float16") = params[285]
            lv56_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv687, lv688, lv686_1, lv55_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv861: R.Tensor((4096,), dtype="float16") = params[291]
            lv3054 = R.call_tir(cls.rms_norm1, (lv56_3, lv861), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv691: R.Tensor((22016, 512), dtype="uint32") = params[286]
            lv692: R.Tensor((22016, 128), dtype="float16") = params[287]
            lv57_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv691, lv692, lv3054), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv694 = R.call_tir(cls.fused_split_silu_multiply, (lv57_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv695: R.Tensor((4096, 1376), dtype="uint32") = params[288]
            lv696: R.Tensor((4096, 344), dtype="float16") = params[289]
            lv57_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv695, lv696, lv694, lv56_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
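            # Next decoder layer: same pattern, kv_cache slots 58/59.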
            lv868: R.Tensor((4096,), dtype="float16") = params[300]
            lv3065 = R.call_tir(cls.rms_norm1, (lv57_2, lv868), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv699: R.Tensor((12288, 512), dtype="uint32") = params[292]
            lv700_1: R.Tensor((12288, 128), dtype="float16") = params[293]
            lv58_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv699, lv700_1, lv3065), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv116_1 = R.call_tir(cls.split_rotary, (lv58_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv702: R.Tensor((1, 1, 4096), dtype="float16") = lv116_1[0]
            lv703 = R.call_tir(cls.fused_reshape2_transpose5, (lv702,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv704: R.Tensor((1, 1, 4096), dtype="float16") = lv116_1[1]
            lv705 = R.call_tir(cls.fused_reshape2_squeeze, (lv704,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv706: R.Tensor((1, 1, 4096), dtype="float16") = lv116_1[2]
            lv707_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv706,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv3079: R.Object = kv_cache[58]
            lv3080: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3079, lv705, sinfo_args=(R.Object,))
            lv3081: R.Object = kv_cache[59]
            lv3082: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3081, lv707_1, sinfo_args=(R.Object,))
            lv3083: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3080, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv3084: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3082, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv3085 = R.call_tir(cls.reshape3, (lv3083,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv3086 = R.call_tir(cls.reshape3, (lv3084,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv3088 = R.call_tir(cls.transpose6, (lv3085,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv3089 = R.call_tir(cls.transpose6, (lv3086,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv708 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv703, lv3088, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv709 = R.call_tir(cls.fused_softmax_cast1, (lv708,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv3098 = R.call_tir(cls.matmul9, (lv709, lv3089), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv710 = R.call_tir(cls.fused_transpose7_reshape4, (lv3098,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv711: R.Tensor((4096, 512), dtype="uint32") = params[294]
            lv712: R.Tensor((4096, 128), dtype="float16") = params[295]
            lv58_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv711, lv712, lv710, lv57_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv875: R.Tensor((4096,), dtype="float16") = params[301]
            lv3104 = R.call_tir(cls.rms_norm1, (lv58_2, lv875), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv715: R.Tensor((22016, 512), dtype="uint32") = params[296]
            lv716: R.Tensor((22016, 128), dtype="float16") = params[297]
            lv59_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv715, lv716, lv3104), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv718 = R.call_tir(cls.fused_split_silu_multiply, (lv59_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv719: R.Tensor((4096, 1376), dtype="uint32") = params[298]
            lv720: R.Tensor((4096, 344), dtype="float16") = params[299]
            lv59_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv719, lv720, lv718, lv58_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
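            # Next decoder layer: same pattern, kv_cache slots 60/61.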
            lv882: R.Tensor((4096,), dtype="float16") = params[310]
            lv3115 = R.call_tir(cls.rms_norm1, (lv59_2, lv882), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv723: R.Tensor((12288, 512), dtype="uint32") = params[302]
            lv724: R.Tensor((12288, 128), dtype="float16") = params[303]
            lv60_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv723, lv724, lv3115), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv120_1 = R.call_tir(cls.split_rotary, (lv60_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv726: R.Tensor((1, 1, 4096), dtype="float16") = lv120_1[0]
            lv727 = R.call_tir(cls.fused_reshape2_transpose5, (lv726,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv728_1: R.Tensor((1, 1, 4096), dtype="float16") = lv120_1[1]
            lv729 = R.call_tir(cls.fused_reshape2_squeeze, (lv728_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv730: R.Tensor((1, 1, 4096), dtype="float16") = lv120_1[2]
            lv731 = R.call_tir(cls.fused_reshape2_squeeze, (lv730,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv3129: R.Object = kv_cache[60]
            lv3130: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3129, lv729, sinfo_args=(R.Object,))
            lv3131: R.Object = kv_cache[61]
            lv3132: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3131, lv731, sinfo_args=(R.Object,))
            lv3133: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3130, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv3134: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3132, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv3135 = R.call_tir(cls.reshape3, (lv3133,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv3136 = R.call_tir(cls.reshape3, (lv3134,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv3138 = R.call_tir(cls.transpose6, (lv3135,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv3139 = R.call_tir(cls.transpose6, (lv3136,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv732 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv727, lv3138, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv733 = R.call_tir(cls.fused_softmax_cast1, (lv732,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv3148 = R.call_tir(cls.matmul9, (lv733, lv3139), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv734 = R.call_tir(cls.fused_transpose7_reshape4, (lv3148,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv735_1: R.Tensor((4096, 512), dtype="uint32") = params[304]
            lv736: R.Tensor((4096, 128), dtype="float16") = params[305]
            lv60_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv735_1, lv736, lv734, lv59_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv889: R.Tensor((4096,), dtype="float16") = params[311]
            lv3154 = R.call_tir(cls.rms_norm1, (lv60_3, lv889), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv739: R.Tensor((22016, 512), dtype="uint32") = params[306]
            lv740: R.Tensor((22016, 128), dtype="float16") = params[307]
            lv61_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv739, lv740, lv3154), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv742_1 = R.call_tir(cls.fused_split_silu_multiply, (lv61_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv743: R.Tensor((4096, 1376), dtype="uint32") = params[308]
            lv744: R.Tensor((4096, 344), dtype="float16") = params[309]
            lv61_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv743, lv744, lv742_1, lv60_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
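            # Last decoder layer in this function: same pattern, kv_cache slots 62/63.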
            lv896: R.Tensor((4096,), dtype="float16") = params[320]
            lv3165 = R.call_tir(cls.rms_norm1, (lv61_2, lv896), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv747: R.Tensor((12288, 512), dtype="uint32") = params[312]
            lv748: R.Tensor((12288, 128), dtype="float16") = params[313]
            lv62_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv747, lv748, lv3165), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
            lv124_1 = R.call_tir(cls.split_rotary, (lv62_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
            lv750: R.Tensor((1, 1, 4096), dtype="float16") = lv124_1[0]
            lv751 = R.call_tir(cls.fused_reshape2_transpose5, (lv750,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv752: R.Tensor((1, 1, 4096), dtype="float16") = lv124_1[1]
            lv753 = R.call_tir(cls.fused_reshape2_squeeze, (lv752,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv754: R.Tensor((1, 1, 4096), dtype="float16") = lv124_1[2]
            lv755 = R.call_tir(cls.fused_reshape2_squeeze, (lv754,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
            lv3179: R.Object = kv_cache[62]
            lv3180: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3179, lv753, sinfo_args=(R.Object,))
            lv3181: R.Object = kv_cache[63]
            lv3182: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3181, lv755, sinfo_args=(R.Object,))
            lv3183: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3180, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv3184: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3182, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
            lv3185 = R.call_tir(cls.reshape3, (lv3183,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv3186 = R.call_tir(cls.reshape3, (lv3184,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv3188 = R.call_tir(cls.transpose6, (lv3185,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv3189 = R.call_tir(cls.transpose6, (lv3186,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv756_1 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv751, lv3188, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
            lv757 = R.call_tir(cls.fused_softmax_cast1, (lv756_1,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv3198 = R.call_tir(cls.matmul9, (lv757, lv3189), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            lv758 = R.call_tir(cls.fused_transpose7_reshape4, (lv3198,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv759: R.Tensor((4096, 512), dtype="uint32") = params[314]
            lv760: R.Tensor((4096, 128), dtype="float16") = params[315]
            lv62_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv759, lv760, lv758, lv61_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv903: R.Tensor((4096,), dtype="float16") = params[321]
            lv3204 = R.call_tir(cls.rms_norm1, (lv62_2, lv903), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv763_1: R.Tensor((22016, 512), dtype="uint32") = params[316]
            lv764: R.Tensor((22016, 128), dtype="float16") = params[317]
            lv63_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv763_1, lv764, lv3204), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv766 = R.call_tir(cls.fused_split_silu_multiply, (lv63_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv767: R.Tensor((4096, 1376), dtype="uint32") = params[318]
            lv768: R.Tensor((4096, 344), dtype="float16") = params[319]
            lv63_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv767, lv768, lv766, lv62_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
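            # Head: final rms_norm, slice out the last position, then the dequantized
            # lm_head matmul giving float32 logits over the 32001-token vocabulary.
            # gv1 below bundles these logits with all 64 updated KV-cache handles.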
            lv910: R.Tensor((4096,), dtype="float16") = params[322]
            lv3215 = R.call_tir(cls.rms_norm1, (lv63_2, lv910), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv3216 = R.call_tir(cls.slice1, (lv3215,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv771: R.Tensor((32001, 512), dtype="uint32") = params[323]
            lv772: R.Tensor((32001, 128), dtype="float16") = params[324]
            lv64_2 = R.call_tir(cls.fused_fused_decode1_fused_NT_matmul5_cast2, (lv771, lv772, lv3216), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
            gv1: R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)) = lv64_2, (lv1630, lv1632, lv1680, lv1682, lv1730, lv1732, lv1780, lv1782, lv1830, lv1832, lv1880, lv1882, lv1930, lv1932, lv1980, lv1982, lv2030, lv2032, lv2080, lv2082, lv2130, lv2132, lv2180, lv2182, lv2230, lv2232, lv2280, lv2282, lv2330, lv2332, lv2380, lv2382, lv2430, lv2432, lv2480, lv2482, lv2530, lv2532, lv2580, lv2582, lv2630, lv2632, lv2680, lv2682, lv2730, lv2732, lv2780, lv2782, lv2830, lv2832, lv2880, lv2882, lv2930, lv2932, lv2980, lv2982, lv3030, lv3032, lv3080, lv3082, lv3130, lv3132, lv3180, lv3182)
            R.output(gv1)
        return gv1

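    # get_metadata: returns the model's compile-time metadata (model name,
    # context-window size, stop tokens) as a JSON string.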
    @R.function
    def get_metadata() -> R.Object:
        R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
        return R.str("{\"model_name\": \"WizardMath-7B-V1.0\", \"max_window_size\": 2048, \"stop_tokens\": [2], \"add_prefix_space\": false}")

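    # prefill: runs all n prompt tokens through the decoder stack in one pass;
    # m is the total sequence length including the cache (both capped at 2048 via
    # tir_var_upper_bound). The params tuple appears to interleave, per layer, the
    # packed quantized weights (uint32 data with float16 scales) for the QKV,
    # attention-output, MLP-up and MLP-down projections plus two rms_norm weight
    # vectors, followed by the final norm, lm_head and rotary tables. Returns the
    # logits for the last position together with the 64 updated KV-cache objects.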
    @R.function
    def prefill(
        input_ids: R.Tensor((1, "n"), dtype="int32"),
        all_seq_len: R.Shape(["m"]),
        kv_cache: R.Tuple(
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
        ),
        params: R.Tuple(
            R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"),
            R.Tensor((4096,), dtype="float16"),
            R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"),
            R.Tensor((2048, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"),
        ),
    ) -> R.Tuple(
        R.Tensor((1, 1, 32001), dtype="float32"),
        R.Tuple(
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
            R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
        ),
    ):
        n = T.int64()
        m = T.int64()
        R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
        cls = Module
        with R.dataflow():
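            # Prologue: flatten input_ids, gather the n token embeddings from the
            # dequantized embedding table, and build the (1, 1, n, m) causal mask.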
            lv = R.call_tir(cls.reshape5, (input_ids,), out_sinfo=R.Tensor((n,), dtype="int32"))
            lv775: R.Tensor((32001, 512), dtype="uint32") = params[0]
            lv776: R.Tensor((32001, 128), dtype="float16") = params[1]
            lv_1 = R.call_tir(cls.fused_fused_decode1_take1, (lv775, lv776, lv), out_sinfo=R.Tensor((n, 4096), dtype="float16"))
            lv2 = R.call_tir(cls.reshape6, (lv_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv778 = R.call_tir(cls.fused_min_max_triu_te_broadcast_to, R.tuple(), out_sinfo=R.Tensor((1, 1, n, n), dtype="float16"), tir_vars=R.shape([n]))
            lv5 = R.call_tir(cls.extend_te, (lv778,), out_sinfo=R.Tensor((1, 1, n, m), dtype="float16"))
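            # First decoder layer of the prefill pass (kv_cache slots 0/1). Unlike the
            # decode pass above, attention here is (n x m): every query position
            # attends to all m cached key positions.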
            lv3: R.Tensor((4096,), dtype="float16") = params[10]
            lv6 = R.call_tir(cls.rms_norm, (lv2, lv3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv779: R.Tensor((12288, 512), dtype="uint32") = params[2]
            lv780: R.Tensor((12288, 128), dtype="float16") = params[3]
            lv65 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv779, lv780, lv6), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv9 = R.call_tir(cls.split1, (lv65,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv10: R.Tensor((1, n, 4096), dtype="float16") = lv9[0]
            lv11 = R.call_tir(cls.reshape7, (lv10,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv12: R.Tensor((1, n, 4096), dtype="float16") = lv9[1]
            lv13 = R.call_tir(cls.reshape7, (lv12,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv14: R.Tensor((1, n, 4096), dtype="float16") = lv9[2]
            lv15 = R.call_tir(cls.reshape7, (lv14,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv7: R.Tensor((2048, 128), dtype="float16") = params[325]
            lv8: R.Tensor((2048, 128), dtype="float16") = params[326]
            lv16 = R.call_tir(cls.rotary_embedding, (lv11, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv17 = R.call_tir(cls.rotary_embedding, (lv13, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv18 = R.call_tir(cls.squeeze1, (lv17,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv19 = R.call_tir(cls.squeeze1, (lv15,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv20: R.Object = kv_cache[0]
            lv21: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv20, lv18, sinfo_args=(R.Object,))
            lv22: R.Object = kv_cache[1]
            lv23: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv22, lv19, sinfo_args=(R.Object,))
            lv24: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv21, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv25: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv23, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv26 = R.call_tir(cls.reshape3, (lv24,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv27 = R.call_tir(cls.reshape3, (lv25,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv28 = R.call_tir(cls.transpose6, (lv16,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv29 = R.call_tir(cls.transpose6, (lv26,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv30 = R.call_tir(cls.transpose6, (lv27,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv782 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv28, lv29, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv783 = R.call_tir(cls.fused_softmax1_cast4, (lv782,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv39 = R.call_tir(cls.matmul10, (lv783, lv30), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv40 = R.call_tir(cls.transpose8, (lv39,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv41 = R.call_tir(cls.reshape8, (lv40,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv784: R.Tensor((4096, 512), dtype="uint32") = params[4]
            lv785: R.Tensor((4096, 128), dtype="float16") = params[5]
            lv64 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv784, lv785, lv41, lv2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv12_1: R.Tensor((4096,), dtype="float16") = params[11]
            lv45 = R.call_tir(cls.rms_norm, (lv64, lv12_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv788: R.Tensor((22016, 512), dtype="uint32") = params[6]
            lv789: R.Tensor((22016, 128), dtype="float16") = params[7]
            lv66 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv788, lv789, lv45), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv791 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv66,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv792: R.Tensor((4096, 1376), dtype="uint32") = params[8]
            lv793: R.Tensor((4096, 344), dtype="float16") = params[9]
            lv65_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv792, lv793, lv791, lv64), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
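            # Prefill decoder layer using kv_cache slots 2/3 (same pattern as above).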
            lv19_1: R.Tensor((4096,), dtype="float16") = params[20]
            lv56 = R.call_tir(cls.rms_norm, (lv65_1, lv19_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv796: R.Tensor((12288, 512), dtype="uint32") = params[12]
            lv797: R.Tensor((12288, 128), dtype="float16") = params[13]
            lv67 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv796, lv797, lv56), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv59 = R.call_tir(cls.split1, (lv67,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv60: R.Tensor((1, n, 4096), dtype="float16") = lv59[0]
            lv61 = R.call_tir(cls.reshape7, (lv60,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv62: R.Tensor((1, n, 4096), dtype="float16") = lv59[1]
            lv63 = R.call_tir(cls.reshape7, (lv62,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv64_1: R.Tensor((1, n, 4096), dtype="float16") = lv59[2]
            lv65_2 = R.call_tir(cls.reshape7, (lv64_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv66_1 = R.call_tir(cls.rotary_embedding, (lv61, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv67_1 = R.call_tir(cls.rotary_embedding, (lv63, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv68 = R.call_tir(cls.squeeze1, (lv67_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv69 = R.call_tir(cls.squeeze1, (lv65_2,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv70: R.Object = kv_cache[2]
            lv71: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv70, lv68, sinfo_args=(R.Object,))
            lv72: R.Object = kv_cache[3]
            lv73: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv72, lv69, sinfo_args=(R.Object,))
            lv74: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv71, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv75: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv73, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv76 = R.call_tir(cls.reshape3, (lv74,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv77 = R.call_tir(cls.reshape3, (lv75,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv78 = R.call_tir(cls.transpose6, (lv66_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv79 = R.call_tir(cls.transpose6, (lv76,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv80 = R.call_tir(cls.transpose6, (lv77,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv799 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv78, lv79, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv800 = R.call_tir(cls.fused_softmax1_cast4, (lv799,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv89 = R.call_tir(cls.matmul10, (lv800, lv80), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv90 = R.call_tir(cls.transpose8, (lv89,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv91 = R.call_tir(cls.reshape8, (lv90,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv801: R.Tensor((4096, 512), dtype="uint32") = params[14]
            lv802: R.Tensor((4096, 128), dtype="float16") = params[15]
            lv66_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv801, lv802, lv91, lv65_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv26_1: R.Tensor((4096,), dtype="float16") = params[21]
            lv95 = R.call_tir(cls.rms_norm, (lv66_2, lv26_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv805: R.Tensor((22016, 512), dtype="uint32") = params[16]
            lv806: R.Tensor((22016, 128), dtype="float16") = params[17]
            lv68_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv805, lv806, lv95), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv808 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv68_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv809: R.Tensor((4096, 1376), dtype="uint32") = params[18]
            lv810: R.Tensor((4096, 344), dtype="float16") = params[19]
            lv67_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv809, lv810, lv808, lv66_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
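            # Prefill decoder layer using kv_cache slots 4/5.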
            lv33: R.Tensor((4096,), dtype="float16") = params[30]
            lv106 = R.call_tir(cls.rms_norm, (lv67_2, lv33), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv813: R.Tensor((12288, 512), dtype="uint32") = params[22]
            lv814: R.Tensor((12288, 128), dtype="float16") = params[23]
            lv69_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv813, lv814, lv106), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv109 = R.call_tir(cls.split1, (lv69_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv110: R.Tensor((1, n, 4096), dtype="float16") = lv109[0]
            lv111 = R.call_tir(cls.reshape7, (lv110,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv112: R.Tensor((1, n, 4096), dtype="float16") = lv109[1]
            lv113 = R.call_tir(cls.reshape7, (lv112,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv114: R.Tensor((1, n, 4096), dtype="float16") = lv109[2]
            lv115 = R.call_tir(cls.reshape7, (lv114,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv116 = R.call_tir(cls.rotary_embedding, (lv111, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv117 = R.call_tir(cls.rotary_embedding, (lv113, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv118 = R.call_tir(cls.squeeze1, (lv117,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv119 = R.call_tir(cls.squeeze1, (lv115,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv120: R.Object = kv_cache[4]
            lv121: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv120, lv118, sinfo_args=(R.Object,))
            lv122: R.Object = kv_cache[5]
            lv123: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv122, lv119, sinfo_args=(R.Object,))
            lv124: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv121, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv125: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv123, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv126 = R.call_tir(cls.reshape3, (lv124,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv127 = R.call_tir(cls.reshape3, (lv125,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv128 = R.call_tir(cls.transpose6, (lv116,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv129 = R.call_tir(cls.transpose6, (lv126,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv130 = R.call_tir(cls.transpose6, (lv127,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
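            # (6) attention: the fused kernel computes QK^T with a scalar divide
            # (presumably 1/sqrt(128)), applies the mask lv5 via its maximum/minimum
            # clamps, and casts to fp32; softmax runs in fp32 and casts back to fp16
            # before the weighted sum over V.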
            lv816 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv128, lv129, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv817 = R.call_tir(cls.fused_softmax1_cast4, (lv816,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv139 = R.call_tir(cls.matmul10, (lv817, lv130), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv140 = R.call_tir(cls.transpose8, (lv139,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv141 = R.call_tir(cls.reshape8, (lv140,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
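            # (7) fused dequantize + output projection, with residual add of the
            # layer input lv67_2.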
            lv818: R.Tensor((4096, 512), dtype="uint32") = params[24]
            lv819: R.Tensor((4096, 128), dtype="float16") = params[25]
            lv68_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv818, lv819, lv141, lv67_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
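            # (8) post-attention rms_norm feeding the gated MLP.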
            lv40_1: R.Tensor((4096,), dtype="float16") = params[31]
            lv145 = R.call_tir(cls.rms_norm, (lv68_2, lv40_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv822: R.Tensor((22016, 512), dtype="uint32") = params[26]
            lv823: R.Tensor((22016, 128), dtype="float16") = params[27]
            lv70_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv822, lv823, lv145), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
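            # fused_split2_silu1_multiply1: split the (1, n, 22016) gate/up projection
            # into two 11008-wide halves and compute silu(gate) * up (per the kernel
            # name and the halved output width).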
            lv825 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv70_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv826: R.Tensor((4096, 1376), dtype="uint32") = params[28]
            lv827: R.Tensor((4096, 344), dtype="float16") = params[29]
            lv69_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv826, lv827, lv825, lv68_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
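            # -- Decoder layer: kv_cache slots 6/7; quantized matmul weights
            # params[32..39]; norm weights params[40]/[41]. Same sequence as the
            # annotated layer above.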
            lv47: R.Tensor((4096,), dtype="float16") = params[40]
            lv156 = R.call_tir(cls.rms_norm, (lv69_2, lv47), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv830: R.Tensor((12288, 512), dtype="uint32") = params[32]
            lv831: R.Tensor((12288, 128), dtype="float16") = params[33]
            lv71_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv830, lv831, lv156), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv159 = R.call_tir(cls.split1, (lv71_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv160: R.Tensor((1, n, 4096), dtype="float16") = lv159[0]
            lv161 = R.call_tir(cls.reshape7, (lv160,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv162: R.Tensor((1, n, 4096), dtype="float16") = lv159[1]
            lv163 = R.call_tir(cls.reshape7, (lv162,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv164: R.Tensor((1, n, 4096), dtype="float16") = lv159[2]
            lv165 = R.call_tir(cls.reshape7, (lv164,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv166 = R.call_tir(cls.rotary_embedding, (lv161, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv167 = R.call_tir(cls.rotary_embedding, (lv163, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv168 = R.call_tir(cls.squeeze1, (lv167,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv169 = R.call_tir(cls.squeeze1, (lv165,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv170: R.Object = kv_cache[6]
            lv171: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv170, lv168, sinfo_args=(R.Object,))
            lv172: R.Object = kv_cache[7]
            lv173: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv172, lv169, sinfo_args=(R.Object,))
            lv174: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv171, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv175: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv173, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv176 = R.call_tir(cls.reshape3, (lv174,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv177 = R.call_tir(cls.reshape3, (lv175,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv178 = R.call_tir(cls.transpose6, (lv166,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv179 = R.call_tir(cls.transpose6, (lv176,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv180 = R.call_tir(cls.transpose6, (lv177,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv833 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv178, lv179, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv834 = R.call_tir(cls.fused_softmax1_cast4, (lv833,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv189 = R.call_tir(cls.matmul10, (lv834, lv180), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv190 = R.call_tir(cls.transpose8, (lv189,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv191 = R.call_tir(cls.reshape8, (lv190,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv835: R.Tensor((4096, 512), dtype="uint32") = params[34]
            lv836: R.Tensor((4096, 128), dtype="float16") = params[35]
            lv70_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv835, lv836, lv191, lv69_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv54: R.Tensor((4096,), dtype="float16") = params[41]
            lv195 = R.call_tir(cls.rms_norm, (lv70_2, lv54), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv839: R.Tensor((22016, 512), dtype="uint32") = params[36]
            lv840: R.Tensor((22016, 128), dtype="float16") = params[37]
            lv72_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv839, lv840, lv195), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv842 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv72_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv843: R.Tensor((4096, 1376), dtype="uint32") = params[38]
            lv844: R.Tensor((4096, 344), dtype="float16") = params[39]
            lv71_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv843, lv844, lv842, lv70_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
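            # -- Decoder layer: kv_cache slots 8/9; quantized matmul weights
            # params[42..49]; norm weights params[50]/[51].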
            lv61_1: R.Tensor((4096,), dtype="float16") = params[50]
            lv206 = R.call_tir(cls.rms_norm, (lv71_2, lv61_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv847: R.Tensor((12288, 512), dtype="uint32") = params[42]
            lv848: R.Tensor((12288, 128), dtype="float16") = params[43]
            lv73_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv847, lv848, lv206), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv209 = R.call_tir(cls.split1, (lv73_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv210: R.Tensor((1, n, 4096), dtype="float16") = lv209[0]
            lv211 = R.call_tir(cls.reshape7, (lv210,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv212: R.Tensor((1, n, 4096), dtype="float16") = lv209[1]
            lv213 = R.call_tir(cls.reshape7, (lv212,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv214: R.Tensor((1, n, 4096), dtype="float16") = lv209[2]
            lv215 = R.call_tir(cls.reshape7, (lv214,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv216 = R.call_tir(cls.rotary_embedding, (lv211, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv217 = R.call_tir(cls.rotary_embedding, (lv213, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv218 = R.call_tir(cls.squeeze1, (lv217,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv219 = R.call_tir(cls.squeeze1, (lv215,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv220: R.Object = kv_cache[8]
            lv221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv220, lv218, sinfo_args=(R.Object,))
            lv222: R.Object = kv_cache[9]
            lv223: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv222, lv219, sinfo_args=(R.Object,))
            lv224: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv221, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv225: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv223, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv226 = R.call_tir(cls.reshape3, (lv224,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv227 = R.call_tir(cls.reshape3, (lv225,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv228 = R.call_tir(cls.transpose6, (lv216,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv229 = R.call_tir(cls.transpose6, (lv226,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv230 = R.call_tir(cls.transpose6, (lv227,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv850 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv228, lv229, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv851 = R.call_tir(cls.fused_softmax1_cast4, (lv850,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv239 = R.call_tir(cls.matmul10, (lv851, lv230), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv240 = R.call_tir(cls.transpose8, (lv239,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv241 = R.call_tir(cls.reshape8, (lv240,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv852: R.Tensor((4096, 512), dtype="uint32") = params[44]
            lv853: R.Tensor((4096, 128), dtype="float16") = params[45]
            lv72_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv852, lv853, lv241, lv71_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv68_3: R.Tensor((4096,), dtype="float16") = params[51]
            lv245 = R.call_tir(cls.rms_norm, (lv72_2, lv68_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv856: R.Tensor((22016, 512), dtype="uint32") = params[46]
            lv857: R.Tensor((22016, 128), dtype="float16") = params[47]
            lv74_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv856, lv857, lv245), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv859 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv74_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv860: R.Tensor((4096, 1376), dtype="uint32") = params[48]
            lv861: R.Tensor((4096, 344), dtype="float16") = params[49]
            lv73_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv860, lv861, lv859, lv72_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
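            # -- Decoder layer: kv_cache slots 10/11; quantized matmul weights
            # params[52..59]; norm weights params[60]/[61].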
            lv75_1: R.Tensor((4096,), dtype="float16") = params[60]
            lv256 = R.call_tir(cls.rms_norm, (lv73_2, lv75_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv864: R.Tensor((12288, 512), dtype="uint32") = params[52]
            lv865: R.Tensor((12288, 128), dtype="float16") = params[53]
            lv75_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv864, lv865, lv256), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv259 = R.call_tir(cls.split1, (lv75_2,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv260: R.Tensor((1, n, 4096), dtype="float16") = lv259[0]
            lv261 = R.call_tir(cls.reshape7, (lv260,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv262: R.Tensor((1, n, 4096), dtype="float16") = lv259[1]
            lv263 = R.call_tir(cls.reshape7, (lv262,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv264: R.Tensor((1, n, 4096), dtype="float16") = lv259[2]
            lv265 = R.call_tir(cls.reshape7, (lv264,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv266 = R.call_tir(cls.rotary_embedding, (lv261, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv267 = R.call_tir(cls.rotary_embedding, (lv263, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv268 = R.call_tir(cls.squeeze1, (lv267,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv269 = R.call_tir(cls.squeeze1, (lv265,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv270: R.Object = kv_cache[10]
            lv271: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv270, lv268, sinfo_args=(R.Object,))
            lv272: R.Object = kv_cache[11]
            lv273: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv272, lv269, sinfo_args=(R.Object,))
            lv274: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv271, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv275: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv273, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv276 = R.call_tir(cls.reshape3, (lv274,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv277 = R.call_tir(cls.reshape3, (lv275,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv278 = R.call_tir(cls.transpose6, (lv266,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv279 = R.call_tir(cls.transpose6, (lv276,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv280 = R.call_tir(cls.transpose6, (lv277,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv867 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv278, lv279, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv868 = R.call_tir(cls.fused_softmax1_cast4, (lv867,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv289 = R.call_tir(cls.matmul10, (lv868, lv280), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv290 = R.call_tir(cls.transpose8, (lv289,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv291 = R.call_tir(cls.reshape8, (lv290,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv869: R.Tensor((4096, 512), dtype="uint32") = params[54]
            lv870: R.Tensor((4096, 128), dtype="float16") = params[55]
            lv74_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv869, lv870, lv291, lv73_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv82: R.Tensor((4096,), dtype="float16") = params[61]
            lv295 = R.call_tir(cls.rms_norm, (lv74_2, lv82), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv873: R.Tensor((22016, 512), dtype="uint32") = params[56]
            lv874: R.Tensor((22016, 128), dtype="float16") = params[57]
            lv76_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv873, lv874, lv295), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv876 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv76_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv877: R.Tensor((4096, 1376), dtype="uint32") = params[58]
            lv878: R.Tensor((4096, 344), dtype="float16") = params[59]
            lv75_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv877, lv878, lv876, lv74_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
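            # -- Decoder layer: kv_cache slots 12/13; quantized matmul weights
            # params[62..69]; norm weights params[70]/[71].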
            lv89_1: R.Tensor((4096,), dtype="float16") = params[70]
            lv306 = R.call_tir(cls.rms_norm, (lv75_3, lv89_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv881: R.Tensor((12288, 512), dtype="uint32") = params[62]
            lv882: R.Tensor((12288, 128), dtype="float16") = params[63]
            lv77_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv881, lv882, lv306), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv309 = R.call_tir(cls.split1, (lv77_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv310: R.Tensor((1, n, 4096), dtype="float16") = lv309[0]
            lv311 = R.call_tir(cls.reshape7, (lv310,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv312: R.Tensor((1, n, 4096), dtype="float16") = lv309[1]
            lv313 = R.call_tir(cls.reshape7, (lv312,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv314: R.Tensor((1, n, 4096), dtype="float16") = lv309[2]
            lv315 = R.call_tir(cls.reshape7, (lv314,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv316 = R.call_tir(cls.rotary_embedding, (lv311, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv317 = R.call_tir(cls.rotary_embedding, (lv313, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv318 = R.call_tir(cls.squeeze1, (lv317,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv319 = R.call_tir(cls.squeeze1, (lv315,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv320: R.Object = kv_cache[12]
            lv321: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv320, lv318, sinfo_args=(R.Object,))
            lv322: R.Object = kv_cache[13]
            lv323: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv322, lv319, sinfo_args=(R.Object,))
            lv324: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv321, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv325: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv323, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv326 = R.call_tir(cls.reshape3, (lv324,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv327 = R.call_tir(cls.reshape3, (lv325,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv328 = R.call_tir(cls.transpose6, (lv316,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv329 = R.call_tir(cls.transpose6, (lv326,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv330 = R.call_tir(cls.transpose6, (lv327,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv884 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv328, lv329, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv885 = R.call_tir(cls.fused_softmax1_cast4, (lv884,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv339 = R.call_tir(cls.matmul10, (lv885, lv330), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv340 = R.call_tir(cls.transpose8, (lv339,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv341 = R.call_tir(cls.reshape8, (lv340,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv886: R.Tensor((4096, 512), dtype="uint32") = params[64]
            lv887: R.Tensor((4096, 128), dtype="float16") = params[65]
            lv76_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv886, lv887, lv341, lv75_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv96: R.Tensor((4096,), dtype="float16") = params[71]
            lv345 = R.call_tir(cls.rms_norm, (lv76_2, lv96), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv890: R.Tensor((22016, 512), dtype="uint32") = params[66]
            lv891: R.Tensor((22016, 128), dtype="float16") = params[67]
            lv78_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv890, lv891, lv345), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv893 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv78_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv894: R.Tensor((4096, 1376), dtype="uint32") = params[68]
            lv895: R.Tensor((4096, 344), dtype="float16") = params[69]
            lv77_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv894, lv895, lv893, lv76_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
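            # -- Decoder layer: kv_cache slots 14/15; quantized matmul weights
            # params[72..79]; norm weights params[80]/[81].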
            lv103: R.Tensor((4096,), dtype="float16") = params[80]
            lv356 = R.call_tir(cls.rms_norm, (lv77_2, lv103), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv898: R.Tensor((12288, 512), dtype="uint32") = params[72]
            lv899: R.Tensor((12288, 128), dtype="float16") = params[73]
            lv79_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv898, lv899, lv356), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv359 = R.call_tir(cls.split1, (lv79_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv360: R.Tensor((1, n, 4096), dtype="float16") = lv359[0]
            lv361 = R.call_tir(cls.reshape7, (lv360,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv362: R.Tensor((1, n, 4096), dtype="float16") = lv359[1]
            lv363 = R.call_tir(cls.reshape7, (lv362,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv364: R.Tensor((1, n, 4096), dtype="float16") = lv359[2]
            lv365 = R.call_tir(cls.reshape7, (lv364,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv366 = R.call_tir(cls.rotary_embedding, (lv361, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv367 = R.call_tir(cls.rotary_embedding, (lv363, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv368 = R.call_tir(cls.squeeze1, (lv367,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv369 = R.call_tir(cls.squeeze1, (lv365,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv370: R.Object = kv_cache[14]
            lv371: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv370, lv368, sinfo_args=(R.Object,))
            lv372: R.Object = kv_cache[15]
            lv373: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv372, lv369, sinfo_args=(R.Object,))
            lv374: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv371, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv375: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv373, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv376 = R.call_tir(cls.reshape3, (lv374,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv377 = R.call_tir(cls.reshape3, (lv375,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv378 = R.call_tir(cls.transpose6, (lv366,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv379 = R.call_tir(cls.transpose6, (lv376,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv380 = R.call_tir(cls.transpose6, (lv377,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv901 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv378, lv379, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv902 = R.call_tir(cls.fused_softmax1_cast4, (lv901,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv389 = R.call_tir(cls.matmul10, (lv902, lv380), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv390 = R.call_tir(cls.transpose8, (lv389,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv391 = R.call_tir(cls.reshape8, (lv390,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv903: R.Tensor((4096, 512), dtype="uint32") = params[74]
            lv904: R.Tensor((4096, 128), dtype="float16") = params[75]
            lv78_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv903, lv904, lv391, lv77_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv110_1: R.Tensor((4096,), dtype="float16") = params[81]
            lv395 = R.call_tir(cls.rms_norm, (lv78_2, lv110_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv907: R.Tensor((22016, 512), dtype="uint32") = params[76]
            lv908: R.Tensor((22016, 128), dtype="float16") = params[77]
            lv80_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv907, lv908, lv395), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv910 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv80_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv911: R.Tensor((4096, 1376), dtype="uint32") = params[78]
            lv912: R.Tensor((4096, 344), dtype="float16") = params[79]
            lv79_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv911, lv912, lv910, lv78_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
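            # -- Decoder layer: kv_cache slots 16/17; quantized matmul weights
            # params[82..89]; norm weights params[90]/[91].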
            lv117_1: R.Tensor((4096,), dtype="float16") = params[90]
            lv406 = R.call_tir(cls.rms_norm, (lv79_2, lv117_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv915: R.Tensor((12288, 512), dtype="uint32") = params[82]
            lv916: R.Tensor((12288, 128), dtype="float16") = params[83]
            lv81 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv915, lv916, lv406), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv409 = R.call_tir(cls.split1, (lv81,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv410: R.Tensor((1, n, 4096), dtype="float16") = lv409[0]
            lv411 = R.call_tir(cls.reshape7, (lv410,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv412: R.Tensor((1, n, 4096), dtype="float16") = lv409[1]
            lv413 = R.call_tir(cls.reshape7, (lv412,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv414: R.Tensor((1, n, 4096), dtype="float16") = lv409[2]
            lv415 = R.call_tir(cls.reshape7, (lv414,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv416 = R.call_tir(cls.rotary_embedding, (lv411, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv417 = R.call_tir(cls.rotary_embedding, (lv413, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv418 = R.call_tir(cls.squeeze1, (lv417,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv419 = R.call_tir(cls.squeeze1, (lv415,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv420: R.Object = kv_cache[16]
            lv421: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv420, lv418, sinfo_args=(R.Object,))
            lv422: R.Object = kv_cache[17]
            lv423: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv422, lv419, sinfo_args=(R.Object,))
            lv424: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv421, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv425: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv423, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv426 = R.call_tir(cls.reshape3, (lv424,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv427 = R.call_tir(cls.reshape3, (lv425,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv428 = R.call_tir(cls.transpose6, (lv416,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv429 = R.call_tir(cls.transpose6, (lv426,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv430 = R.call_tir(cls.transpose6, (lv427,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv918 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv428, lv429, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv919 = R.call_tir(cls.fused_softmax1_cast4, (lv918,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv439 = R.call_tir(cls.matmul10, (lv919, lv430), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv440 = R.call_tir(cls.transpose8, (lv439,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv441 = R.call_tir(cls.reshape8, (lv440,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv920: R.Tensor((4096, 512), dtype="uint32") = params[84]
            lv921: R.Tensor((4096, 128), dtype="float16") = params[85]
            lv80_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv920, lv921, lv441, lv79_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv124_1: R.Tensor((4096,), dtype="float16") = params[91]
            lv445 = R.call_tir(cls.rms_norm, (lv80_2, lv124_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv924: R.Tensor((22016, 512), dtype="uint32") = params[86]
            lv925: R.Tensor((22016, 128), dtype="float16") = params[87]
            lv82_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv924, lv925, lv445), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv927 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv82_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv928: R.Tensor((4096, 1376), dtype="uint32") = params[88]
            lv929: R.Tensor((4096, 344), dtype="float16") = params[89]
            lv81_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv928, lv929, lv927, lv80_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
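            # -- Decoder layer: kv_cache slots 18/19; quantized matmul weights
            # params[92..99]; norm weights params[100]/[101].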
            lv131: R.Tensor((4096,), dtype="float16") = params[100]
            lv456 = R.call_tir(cls.rms_norm, (lv81_1, lv131), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv932: R.Tensor((12288, 512), dtype="uint32") = params[92]
            lv933: R.Tensor((12288, 128), dtype="float16") = params[93]
            lv83 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv932, lv933, lv456), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv459 = R.call_tir(cls.split1, (lv83,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv460: R.Tensor((1, n, 4096), dtype="float16") = lv459[0]
            lv461 = R.call_tir(cls.reshape7, (lv460,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv462: R.Tensor((1, n, 4096), dtype="float16") = lv459[1]
            lv463 = R.call_tir(cls.reshape7, (lv462,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv464: R.Tensor((1, n, 4096), dtype="float16") = lv459[2]
            lv465 = R.call_tir(cls.reshape7, (lv464,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv466 = R.call_tir(cls.rotary_embedding, (lv461, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv467 = R.call_tir(cls.rotary_embedding, (lv463, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv468 = R.call_tir(cls.squeeze1, (lv467,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv469 = R.call_tir(cls.squeeze1, (lv465,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv470: R.Object = kv_cache[18]
            lv471: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv470, lv468, sinfo_args=(R.Object,))
            lv472: R.Object = kv_cache[19]
            lv473: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv472, lv469, sinfo_args=(R.Object,))
            lv474: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv471, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv475: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv473, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv476 = R.call_tir(cls.reshape3, (lv474,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv477 = R.call_tir(cls.reshape3, (lv475,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv478 = R.call_tir(cls.transpose6, (lv466,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv479 = R.call_tir(cls.transpose6, (lv476,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv480 = R.call_tir(cls.transpose6, (lv477,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv935 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv478, lv479, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv936 = R.call_tir(cls.fused_softmax1_cast4, (lv935,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv489 = R.call_tir(cls.matmul10, (lv936, lv480), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv490 = R.call_tir(cls.transpose8, (lv489,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv491 = R.call_tir(cls.reshape8, (lv490,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv937: R.Tensor((4096, 512), dtype="uint32") = params[94]
            lv938: R.Tensor((4096, 128), dtype="float16") = params[95]
            lv82_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv937, lv938, lv491, lv81_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv138: R.Tensor((4096,), dtype="float16") = params[101]
            lv495 = R.call_tir(cls.rms_norm, (lv82_2, lv138), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv941: R.Tensor((22016, 512), dtype="uint32") = params[96]
            lv942: R.Tensor((22016, 128), dtype="float16") = params[97]
            lv84 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv941, lv942, lv495), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv944 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv84,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv945: R.Tensor((4096, 1376), dtype="uint32") = params[98]
            lv946: R.Tensor((4096, 344), dtype="float16") = params[99]
            lv83_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv945, lv946, lv944, lv82_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
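            # -- Decoder layer: kv_cache slots 20/21; quantized matmul weights
            # params[102..109]; norm weights params[110]/[111].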
            lv145_1: R.Tensor((4096,), dtype="float16") = params[110]
            lv506 = R.call_tir(cls.rms_norm, (lv83_1, lv145_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv949: R.Tensor((12288, 512), dtype="uint32") = params[102]
            lv950: R.Tensor((12288, 128), dtype="float16") = params[103]
            lv85 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv949, lv950, lv506), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv509 = R.call_tir(cls.split1, (lv85,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv510: R.Tensor((1, n, 4096), dtype="float16") = lv509[0]
            lv511 = R.call_tir(cls.reshape7, (lv510,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv512: R.Tensor((1, n, 4096), dtype="float16") = lv509[1]
            lv513 = R.call_tir(cls.reshape7, (lv512,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv514: R.Tensor((1, n, 4096), dtype="float16") = lv509[2]
            lv515 = R.call_tir(cls.reshape7, (lv514,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv516 = R.call_tir(cls.rotary_embedding, (lv511, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv517 = R.call_tir(cls.rotary_embedding, (lv513, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv518 = R.call_tir(cls.squeeze1, (lv517,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv519 = R.call_tir(cls.squeeze1, (lv515,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv520: R.Object = kv_cache[20]
            lv521: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv520, lv518, sinfo_args=(R.Object,))
            lv522: R.Object = kv_cache[21]
            lv523: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv522, lv519, sinfo_args=(R.Object,))
            lv524: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv521, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv525: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv523, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv526 = R.call_tir(cls.reshape3, (lv524,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv527 = R.call_tir(cls.reshape3, (lv525,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv528 = R.call_tir(cls.transpose6, (lv516,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv529 = R.call_tir(cls.transpose6, (lv526,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv530 = R.call_tir(cls.transpose6, (lv527,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv952 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv528, lv529, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv953 = R.call_tir(cls.fused_softmax1_cast4, (lv952,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv539 = R.call_tir(cls.matmul10, (lv953, lv530), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv540 = R.call_tir(cls.transpose8, (lv539,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv541 = R.call_tir(cls.reshape8, (lv540,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv954: R.Tensor((4096, 512), dtype="uint32") = params[104]
            lv955: R.Tensor((4096, 128), dtype="float16") = params[105]
            lv84_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv954, lv955, lv541, lv83_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv152: R.Tensor((4096,), dtype="float16") = params[111]
            lv545 = R.call_tir(cls.rms_norm, (lv84_1, lv152), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv958: R.Tensor((22016, 512), dtype="uint32") = params[106]
            lv959: R.Tensor((22016, 128), dtype="float16") = params[107]
            lv86 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv958, lv959, lv545), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv961 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv86,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv962: R.Tensor((4096, 1376), dtype="uint32") = params[108]
            lv963: R.Tensor((4096, 344), dtype="float16") = params[109]
            lv85_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv962, lv963, lv961, lv84_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
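            # -- Decoder layer: kv_cache slots 22/23; quantized matmul weights
            # params[112..119]; norm weights params[120]/[121].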
            lv159_1: R.Tensor((4096,), dtype="float16") = params[120]
            lv556 = R.call_tir(cls.rms_norm, (lv85_1, lv159_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv966: R.Tensor((12288, 512), dtype="uint32") = params[112]
            lv967: R.Tensor((12288, 128), dtype="float16") = params[113]
            lv87 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv966, lv967, lv556), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv559 = R.call_tir(cls.split1, (lv87,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv560: R.Tensor((1, n, 4096), dtype="float16") = lv559[0]
            lv561 = R.call_tir(cls.reshape7, (lv560,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv562: R.Tensor((1, n, 4096), dtype="float16") = lv559[1]
            lv563 = R.call_tir(cls.reshape7, (lv562,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv564: R.Tensor((1, n, 4096), dtype="float16") = lv559[2]
            lv565 = R.call_tir(cls.reshape7, (lv564,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv566 = R.call_tir(cls.rotary_embedding, (lv561, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv567 = R.call_tir(cls.rotary_embedding, (lv563, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv568 = R.call_tir(cls.squeeze1, (lv567,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv569 = R.call_tir(cls.squeeze1, (lv565,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv570: R.Object = kv_cache[22]
            lv571: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv570, lv568, sinfo_args=(R.Object,))
            lv572: R.Object = kv_cache[23]
            lv573: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv572, lv569, sinfo_args=(R.Object,))
            lv574: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv571, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv575: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv573, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv576 = R.call_tir(cls.reshape3, (lv574,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv577 = R.call_tir(cls.reshape3, (lv575,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv578 = R.call_tir(cls.transpose6, (lv566,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv579 = R.call_tir(cls.transpose6, (lv576,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv580 = R.call_tir(cls.transpose6, (lv577,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv969 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv578, lv579, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv970 = R.call_tir(cls.fused_softmax1_cast4, (lv969,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv589 = R.call_tir(cls.matmul10, (lv970, lv580), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv590 = R.call_tir(cls.transpose8, (lv589,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv591 = R.call_tir(cls.reshape8, (lv590,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv971: R.Tensor((4096, 512), dtype="uint32") = params[114]
            lv972: R.Tensor((4096, 128), dtype="float16") = params[115]
            lv86_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv971, lv972, lv591, lv85_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv166_1: R.Tensor((4096,), dtype="float16") = params[121]
            lv595 = R.call_tir(cls.rms_norm, (lv86_1, lv166_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv975: R.Tensor((22016, 512), dtype="uint32") = params[116]
            lv976: R.Tensor((22016, 128), dtype="float16") = params[117]
            lv88 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv975, lv976, lv595), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv978 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv88,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv979: R.Tensor((4096, 1376), dtype="uint32") = params[118]
            lv980: R.Tensor((4096, 344), dtype="float16") = params[119]
            lv87_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv979, lv980, lv978, lv86_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
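            # -- Decoder layer: kv_cache slots 24/25; quantized matmul weights
            # params[122..129]; norm weights params[130]/[131].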
            lv173_1: R.Tensor((4096,), dtype="float16") = params[130]
            lv606 = R.call_tir(cls.rms_norm, (lv87_1, lv173_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv983: R.Tensor((12288, 512), dtype="uint32") = params[122]
            lv984: R.Tensor((12288, 128), dtype="float16") = params[123]
            lv89_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv983, lv984, lv606), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv609 = R.call_tir(cls.split1, (lv89_2,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv610: R.Tensor((1, n, 4096), dtype="float16") = lv609[0]
            lv611 = R.call_tir(cls.reshape7, (lv610,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv612: R.Tensor((1, n, 4096), dtype="float16") = lv609[1]
            lv613 = R.call_tir(cls.reshape7, (lv612,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv614: R.Tensor((1, n, 4096), dtype="float16") = lv609[2]
            lv615 = R.call_tir(cls.reshape7, (lv614,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv616 = R.call_tir(cls.rotary_embedding, (lv611, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv617 = R.call_tir(cls.rotary_embedding, (lv613, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv618 = R.call_tir(cls.squeeze1, (lv617,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv619 = R.call_tir(cls.squeeze1, (lv615,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv620: R.Object = kv_cache[24]
            lv621: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv620, lv618, sinfo_args=(R.Object,))
            lv622: R.Object = kv_cache[25]
            lv623: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv622, lv619, sinfo_args=(R.Object,))
            lv624: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv621, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv625: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv623, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv626 = R.call_tir(cls.reshape3, (lv624,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv627 = R.call_tir(cls.reshape3, (lv625,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv628 = R.call_tir(cls.transpose6, (lv616,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv629 = R.call_tir(cls.transpose6, (lv626,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv630 = R.call_tir(cls.transpose6, (lv627,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv986 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv628, lv629, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv987 = R.call_tir(cls.fused_softmax1_cast4, (lv986,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv639 = R.call_tir(cls.matmul10, (lv987, lv630), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv640 = R.call_tir(cls.transpose8, (lv639,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv641 = R.call_tir(cls.reshape8, (lv640,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv988: R.Tensor((4096, 512), dtype="uint32") = params[124]
            lv989: R.Tensor((4096, 128), dtype="float16") = params[125]
            lv88_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv988, lv989, lv641, lv87_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv180_1: R.Tensor((4096,), dtype="float16") = params[131]
            lv645 = R.call_tir(cls.rms_norm, (lv88_1, lv180_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv992: R.Tensor((22016, 512), dtype="uint32") = params[126]
            lv993: R.Tensor((22016, 128), dtype="float16") = params[127]
            lv90_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv992, lv993, lv645), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv995 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv90_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv996: R.Tensor((4096, 1376), dtype="uint32") = params[128]
            lv997: R.Tensor((4096, 344), dtype="float16") = params[129]
            lv89_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv996, lv997, lv995, lv88_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
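            # -- Decoder layer: kv_cache slots 26/27; quantized matmul weights
            # params[132..139]; norm weights params[140]/[141].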
            lv187: R.Tensor((4096,), dtype="float16") = params[140]
            lv656 = R.call_tir(cls.rms_norm, (lv89_3, lv187), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1000: R.Tensor((12288, 512), dtype="uint32") = params[132]
            lv1001: R.Tensor((12288, 128), dtype="float16") = params[133]
            lv91_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1000, lv1001, lv656), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv659 = R.call_tir(cls.split1, (lv91_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv660: R.Tensor((1, n, 4096), dtype="float16") = lv659[0]
            lv661 = R.call_tir(cls.reshape7, (lv660,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv662: R.Tensor((1, n, 4096), dtype="float16") = lv659[1]
            lv663 = R.call_tir(cls.reshape7, (lv662,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv664: R.Tensor((1, n, 4096), dtype="float16") = lv659[2]
            lv665 = R.call_tir(cls.reshape7, (lv664,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv666 = R.call_tir(cls.rotary_embedding, (lv661, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv667 = R.call_tir(cls.rotary_embedding, (lv663, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv668 = R.call_tir(cls.squeeze1, (lv667,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv669 = R.call_tir(cls.squeeze1, (lv665,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv670: R.Object = kv_cache[26]
            lv671: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv670, lv668, sinfo_args=(R.Object,))
            lv672: R.Object = kv_cache[27]
            lv673: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv672, lv669, sinfo_args=(R.Object,))
            lv674: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv671, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv675: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv673, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv676 = R.call_tir(cls.reshape3, (lv674,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv677 = R.call_tir(cls.reshape3, (lv675,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv678 = R.call_tir(cls.transpose6, (lv666,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv679 = R.call_tir(cls.transpose6, (lv676,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv680 = R.call_tir(cls.transpose6, (lv677,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1003 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv678, lv679, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1004 = R.call_tir(cls.fused_softmax1_cast4, (lv1003,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv689 = R.call_tir(cls.matmul10, (lv1004, lv680), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv690 = R.call_tir(cls.transpose8, (lv689,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv691 = R.call_tir(cls.reshape8, (lv690,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1005: R.Tensor((4096, 512), dtype="uint32") = params[134]
            lv1006: R.Tensor((4096, 128), dtype="float16") = params[135]
            lv90_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1005, lv1006, lv691, lv89_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv194: R.Tensor((4096,), dtype="float16") = params[141]
            lv695 = R.call_tir(cls.rms_norm, (lv90_2, lv194), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1009: R.Tensor((22016, 512), dtype="uint32") = params[136]
            lv1010: R.Tensor((22016, 128), dtype="float16") = params[137]
            lv92 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1009, lv1010, lv695), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1012 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv92,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1013: R.Tensor((4096, 1376), dtype="uint32") = params[138]
            lv1014: R.Tensor((4096, 344), dtype="float16") = params[139]
            lv91_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1013, lv1014, lv1012, lv90_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
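            # -- Decoder layer: kv_cache slots 28/29; quantized matmul weights
            # params[142..149]; norm weights params[150]/[151].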
            lv201: R.Tensor((4096,), dtype="float16") = params[150]
            lv706 = R.call_tir(cls.rms_norm, (lv91_2, lv201), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1017: R.Tensor((12288, 512), dtype="uint32") = params[142]
            lv1018: R.Tensor((12288, 128), dtype="float16") = params[143]
            lv93 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1017, lv1018, lv706), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv709 = R.call_tir(cls.split1, (lv93,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv710: R.Tensor((1, n, 4096), dtype="float16") = lv709[0]
            lv711 = R.call_tir(cls.reshape7, (lv710,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv712: R.Tensor((1, n, 4096), dtype="float16") = lv709[1]
            lv713 = R.call_tir(cls.reshape7, (lv712,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv714: R.Tensor((1, n, 4096), dtype="float16") = lv709[2]
            lv715 = R.call_tir(cls.reshape7, (lv714,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv716 = R.call_tir(cls.rotary_embedding, (lv711, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv717 = R.call_tir(cls.rotary_embedding, (lv713, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv718 = R.call_tir(cls.squeeze1, (lv717,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv719 = R.call_tir(cls.squeeze1, (lv715,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv720: R.Object = kv_cache[28]
            lv721: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv720, lv718, sinfo_args=(R.Object,))
            lv722: R.Object = kv_cache[29]
            lv723: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv722, lv719, sinfo_args=(R.Object,))
            lv724: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv721, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv725: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv723, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv726 = R.call_tir(cls.reshape3, (lv724,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv727 = R.call_tir(cls.reshape3, (lv725,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv728 = R.call_tir(cls.transpose6, (lv716,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv729 = R.call_tir(cls.transpose6, (lv726,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv730 = R.call_tir(cls.transpose6, (lv727,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1020 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv728, lv729, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1021 = R.call_tir(cls.fused_softmax1_cast4, (lv1020,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv739 = R.call_tir(cls.matmul10, (lv1021, lv730), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv740 = R.call_tir(cls.transpose8, (lv739,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv741 = R.call_tir(cls.reshape8, (lv740,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1022: R.Tensor((4096, 512), dtype="uint32") = params[144]
            lv1023: R.Tensor((4096, 128), dtype="float16") = params[145]
            lv92_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1022, lv1023, lv741, lv91_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv208: R.Tensor((4096,), dtype="float16") = params[151]
            lv745 = R.call_tir(cls.rms_norm, (lv92_1, lv208), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1026: R.Tensor((22016, 512), dtype="uint32") = params[146]
            lv1027: R.Tensor((22016, 128), dtype="float16") = params[147]
            lv94 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1026, lv1027, lv745), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1029 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv94,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1030: R.Tensor((4096, 1376), dtype="uint32") = params[148]
            lv1031: R.Tensor((4096, 344), dtype="float16") = params[149]
            lv93_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1030, lv1031, lv1029, lv92_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
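            # --- decoder layer: kv_cache slots 30/31 ---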
            lv215_1: R.Tensor((4096,), dtype="float16") = params[160]
            lv756 = R.call_tir(cls.rms_norm, (lv93_1, lv215_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1034: R.Tensor((12288, 512), dtype="uint32") = params[152]
            lv1035: R.Tensor((12288, 128), dtype="float16") = params[153]
            lv95_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1034, lv1035, lv756), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv759 = R.call_tir(cls.split1, (lv95_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv760: R.Tensor((1, n, 4096), dtype="float16") = lv759[0]
            lv761 = R.call_tir(cls.reshape7, (lv760,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv762: R.Tensor((1, n, 4096), dtype="float16") = lv759[1]
            lv763 = R.call_tir(cls.reshape7, (lv762,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv764: R.Tensor((1, n, 4096), dtype="float16") = lv759[2]
            lv765 = R.call_tir(cls.reshape7, (lv764,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv766 = R.call_tir(cls.rotary_embedding, (lv761, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv767 = R.call_tir(cls.rotary_embedding, (lv763, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv768 = R.call_tir(cls.squeeze1, (lv767,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv769 = R.call_tir(cls.squeeze1, (lv765,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv770: R.Object = kv_cache[30]
            lv771: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv770, lv768, sinfo_args=(R.Object,))
            lv772: R.Object = kv_cache[31]
            lv773: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv772, lv769, sinfo_args=(R.Object,))
            lv774: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv771, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv775_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv773, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv776_1 = R.call_tir(cls.reshape3, (lv774,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv777 = R.call_tir(cls.reshape3, (lv775_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv778_1 = R.call_tir(cls.transpose6, (lv766,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv779_1 = R.call_tir(cls.transpose6, (lv776_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv780_1 = R.call_tir(cls.transpose6, (lv777,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1037 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv778_1, lv779_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1038 = R.call_tir(cls.fused_softmax1_cast4, (lv1037,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv789_1 = R.call_tir(cls.matmul10, (lv1038, lv780_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv790 = R.call_tir(cls.transpose8, (lv789_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv791_1 = R.call_tir(cls.reshape8, (lv790,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1039: R.Tensor((4096, 512), dtype="uint32") = params[154]
            lv1040: R.Tensor((4096, 128), dtype="float16") = params[155]
            lv94_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1039, lv1040, lv791_1, lv93_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv222_1: R.Tensor((4096,), dtype="float16") = params[161]
            lv795 = R.call_tir(cls.rms_norm, (lv94_1, lv222_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1043: R.Tensor((22016, 512), dtype="uint32") = params[156]
            lv1044: R.Tensor((22016, 128), dtype="float16") = params[157]
            lv96_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1043, lv1044, lv795), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1046 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv96_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1047: R.Tensor((4096, 1376), dtype="uint32") = params[158]
            lv1048: R.Tensor((4096, 344), dtype="float16") = params[159]
            lv95_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1047, lv1048, lv1046, lv94_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
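            # --- decoder layer: kv_cache slots 32/33 ---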
            lv229_1: R.Tensor((4096,), dtype="float16") = params[170]
            lv806_1 = R.call_tir(cls.rms_norm, (lv95_2, lv229_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1051: R.Tensor((12288, 512), dtype="uint32") = params[162]
            lv1052: R.Tensor((12288, 128), dtype="float16") = params[163]
            lv97 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1051, lv1052, lv806_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv809_1 = R.call_tir(cls.split1, (lv97,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv810_1: R.Tensor((1, n, 4096), dtype="float16") = lv809_1[0]
            lv811 = R.call_tir(cls.reshape7, (lv810_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv812: R.Tensor((1, n, 4096), dtype="float16") = lv809_1[1]
            lv813_1 = R.call_tir(cls.reshape7, (lv812,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv814_1: R.Tensor((1, n, 4096), dtype="float16") = lv809_1[2]
            lv815 = R.call_tir(cls.reshape7, (lv814_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv816_1 = R.call_tir(cls.rotary_embedding, (lv811, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv817_1 = R.call_tir(cls.rotary_embedding, (lv813_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv818_1 = R.call_tir(cls.squeeze1, (lv817_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv819_1 = R.call_tir(cls.squeeze1, (lv815,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv820: R.Object = kv_cache[32]
            lv821: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv820, lv818_1, sinfo_args=(R.Object,))
            lv822_1: R.Object = kv_cache[33]
            lv823_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv822_1, lv819_1, sinfo_args=(R.Object,))
            lv824: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv821, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv825_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv823_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv826_1 = R.call_tir(cls.reshape3, (lv824,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv827_1 = R.call_tir(cls.reshape3, (lv825_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv828 = R.call_tir(cls.transpose6, (lv816_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv829 = R.call_tir(cls.transpose6, (lv826_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv830_1 = R.call_tir(cls.transpose6, (lv827_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1054 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv828, lv829, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1055 = R.call_tir(cls.fused_softmax1_cast4, (lv1054,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv839_1 = R.call_tir(cls.matmul10, (lv1055, lv830_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv840_1 = R.call_tir(cls.transpose8, (lv839_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv841 = R.call_tir(cls.reshape8, (lv840_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1056: R.Tensor((4096, 512), dtype="uint32") = params[164]
            lv1057: R.Tensor((4096, 128), dtype="float16") = params[165]
            lv96_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1056, lv1057, lv841, lv95_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv236: R.Tensor((4096,), dtype="float16") = params[171]
            lv845 = R.call_tir(cls.rms_norm, (lv96_2, lv236), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1060: R.Tensor((22016, 512), dtype="uint32") = params[166]
            lv1061: R.Tensor((22016, 128), dtype="float16") = params[167]
            lv98 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1060, lv1061, lv845), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1063 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv98,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1064: R.Tensor((4096, 1376), dtype="uint32") = params[168]
            lv1065: R.Tensor((4096, 344), dtype="float16") = params[169]
            lv97_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1064, lv1065, lv1063, lv96_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
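            # --- decoder layer: kv_cache slots 34/35 ---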
            lv243: R.Tensor((4096,), dtype="float16") = params[180]
            lv856_1 = R.call_tir(cls.rms_norm, (lv97_1, lv243), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1068: R.Tensor((12288, 512), dtype="uint32") = params[172]
            lv1069: R.Tensor((12288, 128), dtype="float16") = params[173]
            lv99 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1068, lv1069, lv856_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv859_1 = R.call_tir(cls.split1, (lv99,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv860_1: R.Tensor((1, n, 4096), dtype="float16") = lv859_1[0]
            lv861_1 = R.call_tir(cls.reshape7, (lv860_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv862: R.Tensor((1, n, 4096), dtype="float16") = lv859_1[1]
            lv863 = R.call_tir(cls.reshape7, (lv862,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv864_1: R.Tensor((1, n, 4096), dtype="float16") = lv859_1[2]
            lv865_1 = R.call_tir(cls.reshape7, (lv864_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv866 = R.call_tir(cls.rotary_embedding, (lv861_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv867_1 = R.call_tir(cls.rotary_embedding, (lv863, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv868_1 = R.call_tir(cls.squeeze1, (lv867_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv869_1 = R.call_tir(cls.squeeze1, (lv865_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv870_1: R.Object = kv_cache[34]
            lv871: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv870_1, lv868_1, sinfo_args=(R.Object,))
            lv872: R.Object = kv_cache[35]
            lv873_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv872, lv869_1, sinfo_args=(R.Object,))
            lv874_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv871, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv875: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv873_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv876_1 = R.call_tir(cls.reshape3, (lv874_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv877_1 = R.call_tir(cls.reshape3, (lv875,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv878_1 = R.call_tir(cls.transpose6, (lv866,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv879 = R.call_tir(cls.transpose6, (lv876_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv880 = R.call_tir(cls.transpose6, (lv877_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1071 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv878_1, lv879, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1072 = R.call_tir(cls.fused_softmax1_cast4, (lv1071,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv889 = R.call_tir(cls.matmul10, (lv1072, lv880), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv890_1 = R.call_tir(cls.transpose8, (lv889,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv891_1 = R.call_tir(cls.reshape8, (lv890_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1073: R.Tensor((4096, 512), dtype="uint32") = params[174]
            lv1074: R.Tensor((4096, 128), dtype="float16") = params[175]
            lv98_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1073, lv1074, lv891_1, lv97_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv250: R.Tensor((4096,), dtype="float16") = params[181]
            lv895_1 = R.call_tir(cls.rms_norm, (lv98_1, lv250), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1077: R.Tensor((22016, 512), dtype="uint32") = params[176]
            lv1078: R.Tensor((22016, 128), dtype="float16") = params[177]
            lv100 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1077, lv1078, lv895_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1080 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv100,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1081: R.Tensor((4096, 1376), dtype="uint32") = params[178]
            lv1082: R.Tensor((4096, 344), dtype="float16") = params[179]
            lv99_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1081, lv1082, lv1080, lv98_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
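            # --- decoder layer: kv_cache slots 36/37 ---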
            lv257: R.Tensor((4096,), dtype="float16") = params[190]
            lv906 = R.call_tir(cls.rms_norm, (lv99_1, lv257), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1085: R.Tensor((12288, 512), dtype="uint32") = params[182]
            lv1086: R.Tensor((12288, 128), dtype="float16") = params[183]
            lv101 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1085, lv1086, lv906), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv909 = R.call_tir(cls.split1, (lv101,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv910_1: R.Tensor((1, n, 4096), dtype="float16") = lv909[0]
            lv911_1 = R.call_tir(cls.reshape7, (lv910_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv912_1: R.Tensor((1, n, 4096), dtype="float16") = lv909[1]
            lv913 = R.call_tir(cls.reshape7, (lv912_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv914: R.Tensor((1, n, 4096), dtype="float16") = lv909[2]
            lv915_1 = R.call_tir(cls.reshape7, (lv914,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv916_1 = R.call_tir(cls.rotary_embedding, (lv911_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv917 = R.call_tir(cls.rotary_embedding, (lv913, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv918_1 = R.call_tir(cls.squeeze1, (lv917,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv919_1 = R.call_tir(cls.squeeze1, (lv915_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv920_1: R.Object = kv_cache[36]
            lv921_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv920_1, lv918_1, sinfo_args=(R.Object,))
            lv922: R.Object = kv_cache[37]
            lv923: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv922, lv919_1, sinfo_args=(R.Object,))
            lv924_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv921_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv925_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv923, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv926 = R.call_tir(cls.reshape3, (lv924_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv927_1 = R.call_tir(cls.reshape3, (lv925_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv928_1 = R.call_tir(cls.transpose6, (lv916_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv929_1 = R.call_tir(cls.transpose6, (lv926,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv930 = R.call_tir(cls.transpose6, (lv927_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1088 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv928_1, lv929_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1089 = R.call_tir(cls.fused_softmax1_cast4, (lv1088,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv939 = R.call_tir(cls.matmul10, (lv1089, lv930), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv940 = R.call_tir(cls.transpose8, (lv939,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv941_1 = R.call_tir(cls.reshape8, (lv940,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1090: R.Tensor((4096, 512), dtype="uint32") = params[184]
            lv1091: R.Tensor((4096, 128), dtype="float16") = params[185]
            lv100_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1090, lv1091, lv941_1, lv99_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv264_1: R.Tensor((4096,), dtype="float16") = params[191]
            lv945_1 = R.call_tir(cls.rms_norm, (lv100_1, lv264_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1094: R.Tensor((22016, 512), dtype="uint32") = params[186]
            lv1095: R.Tensor((22016, 128), dtype="float16") = params[187]
            lv102 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1094, lv1095, lv945_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1097 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv102,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1098: R.Tensor((4096, 1376), dtype="uint32") = params[188]
            lv1099: R.Tensor((4096, 344), dtype="float16") = params[189]
            lv101_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1098, lv1099, lv1097, lv100_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
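            # --- decoder layer: kv_cache slots 38/39 ---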
            lv271_1: R.Tensor((4096,), dtype="float16") = params[200]
            lv956 = R.call_tir(cls.rms_norm, (lv101_1, lv271_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1102: R.Tensor((12288, 512), dtype="uint32") = params[192]
            lv1103: R.Tensor((12288, 128), dtype="float16") = params[193]
            lv103_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1102, lv1103, lv956), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv959_1 = R.call_tir(cls.split1, (lv103_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv960: R.Tensor((1, n, 4096), dtype="float16") = lv959_1[0]
            lv961_1 = R.call_tir(cls.reshape7, (lv960,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv962_1: R.Tensor((1, n, 4096), dtype="float16") = lv959_1[1]
            lv963_1 = R.call_tir(cls.reshape7, (lv962_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv964: R.Tensor((1, n, 4096), dtype="float16") = lv959_1[2]
            lv965 = R.call_tir(cls.reshape7, (lv964,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv966_1 = R.call_tir(cls.rotary_embedding, (lv961_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv967_1 = R.call_tir(cls.rotary_embedding, (lv963_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv968 = R.call_tir(cls.squeeze1, (lv967_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv969_1 = R.call_tir(cls.squeeze1, (lv965,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv970_1: R.Object = kv_cache[38]
            lv971_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv970_1, lv968, sinfo_args=(R.Object,))
            lv972_1: R.Object = kv_cache[39]
            lv973: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv972_1, lv969_1, sinfo_args=(R.Object,))
            lv974: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv971_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv975_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv973, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv976_1 = R.call_tir(cls.reshape3, (lv974,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv977 = R.call_tir(cls.reshape3, (lv975_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv978_1 = R.call_tir(cls.transpose6, (lv966_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv979_1 = R.call_tir(cls.transpose6, (lv976_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv980_1 = R.call_tir(cls.transpose6, (lv977,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1105 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv978_1, lv979_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1106 = R.call_tir(cls.fused_softmax1_cast4, (lv1105,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv989_1 = R.call_tir(cls.matmul10, (lv1106, lv980_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv990 = R.call_tir(cls.transpose8, (lv989_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv991 = R.call_tir(cls.reshape8, (lv990,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1107: R.Tensor((4096, 512), dtype="uint32") = params[194]
            lv1108: R.Tensor((4096, 128), dtype="float16") = params[195]
            lv102_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1107, lv1108, lv991, lv101_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv278_1: R.Tensor((4096,), dtype="float16") = params[201]
            lv995_1 = R.call_tir(cls.rms_norm, (lv102_1, lv278_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1111: R.Tensor((22016, 512), dtype="uint32") = params[196]
            lv1112: R.Tensor((22016, 128), dtype="float16") = params[197]
            lv104 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1111, lv1112, lv995_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1114 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv104,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1115: R.Tensor((4096, 1376), dtype="uint32") = params[198]
            lv1116: R.Tensor((4096, 344), dtype="float16") = params[199]
            lv103_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1115, lv1116, lv1114, lv102_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
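            # --- decoder layer: kv_cache slots 40/41 ---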
            lv285: R.Tensor((4096,), dtype="float16") = params[210]
            lv1006_1 = R.call_tir(cls.rms_norm, (lv103_2, lv285), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1119: R.Tensor((12288, 512), dtype="uint32") = params[202]
            lv1120: R.Tensor((12288, 128), dtype="float16") = params[203]
            lv105 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1119, lv1120, lv1006_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1009_1 = R.call_tir(cls.split1, (lv105,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1010_1: R.Tensor((1, n, 4096), dtype="float16") = lv1009_1[0]
            lv1011 = R.call_tir(cls.reshape7, (lv1010_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1012_1: R.Tensor((1, n, 4096), dtype="float16") = lv1009_1[1]
            lv1013_1 = R.call_tir(cls.reshape7, (lv1012_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1014_1: R.Tensor((1, n, 4096), dtype="float16") = lv1009_1[2]
            lv1015 = R.call_tir(cls.reshape7, (lv1014_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1016 = R.call_tir(cls.rotary_embedding, (lv1011, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1017_1 = R.call_tir(cls.rotary_embedding, (lv1013_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1018_1 = R.call_tir(cls.squeeze1, (lv1017_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1019 = R.call_tir(cls.squeeze1, (lv1015,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1020_1: R.Object = kv_cache[40]
            lv1021_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1020_1, lv1018_1, sinfo_args=(R.Object,))
            lv1022_1: R.Object = kv_cache[41]
            lv1023_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1022_1, lv1019, sinfo_args=(R.Object,))
            lv1024: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1021_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1025: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1023_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1026_1 = R.call_tir(cls.reshape3, (lv1024,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1027_1 = R.call_tir(cls.reshape3, (lv1025,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1028 = R.call_tir(cls.transpose6, (lv1016,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1029_1 = R.call_tir(cls.transpose6, (lv1026_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1030_1 = R.call_tir(cls.transpose6, (lv1027_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1122 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1028, lv1029_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1123 = R.call_tir(cls.fused_softmax1_cast4, (lv1122,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1039_1 = R.call_tir(cls.matmul10, (lv1123, lv1030_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1040_1 = R.call_tir(cls.transpose8, (lv1039_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1041 = R.call_tir(cls.reshape8, (lv1040_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1124: R.Tensor((4096, 512), dtype="uint32") = params[204]
            lv1125: R.Tensor((4096, 128), dtype="float16") = params[205]
            lv104_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1124, lv1125, lv1041, lv103_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv292: R.Tensor((4096,), dtype="float16") = params[211]
            lv1045 = R.call_tir(cls.rms_norm, (lv104_1, lv292), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1128: R.Tensor((22016, 512), dtype="uint32") = params[206]
            lv1129: R.Tensor((22016, 128), dtype="float16") = params[207]
            lv106_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1128, lv1129, lv1045), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1131 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv106_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1132: R.Tensor((4096, 1376), dtype="uint32") = params[208]
            lv1133: R.Tensor((4096, 344), dtype="float16") = params[209]
            lv105_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1132, lv1133, lv1131, lv104_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
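            # --- decoder layer: kv_cache slots 42/43 ---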
            lv299: R.Tensor((4096,), dtype="float16") = params[220]
            lv1056_1 = R.call_tir(cls.rms_norm, (lv105_1, lv299), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1136: R.Tensor((12288, 512), dtype="uint32") = params[212]
            lv1137: R.Tensor((12288, 128), dtype="float16") = params[213]
            lv107 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1136, lv1137, lv1056_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1059 = R.call_tir(cls.split1, (lv107,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1060_1: R.Tensor((1, n, 4096), dtype="float16") = lv1059[0]
            lv1061_1 = R.call_tir(cls.reshape7, (lv1060_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1062: R.Tensor((1, n, 4096), dtype="float16") = lv1059[1]
            lv1063_1 = R.call_tir(cls.reshape7, (lv1062,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1064_1: R.Tensor((1, n, 4096), dtype="float16") = lv1059[2]
            lv1065_1 = R.call_tir(cls.reshape7, (lv1064_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1066 = R.call_tir(cls.rotary_embedding, (lv1061_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1067 = R.call_tir(cls.rotary_embedding, (lv1063_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1068_1 = R.call_tir(cls.squeeze1, (lv1067,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1069_1 = R.call_tir(cls.squeeze1, (lv1065_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1070: R.Object = kv_cache[42]
            lv1071_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1070, lv1068_1, sinfo_args=(R.Object,))
            lv1072_1: R.Object = kv_cache[43]
            lv1073_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1072_1, lv1069_1, sinfo_args=(R.Object,))
            lv1074_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1071_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1075: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1073_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1076 = R.call_tir(cls.reshape3, (lv1074_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1077_1 = R.call_tir(cls.reshape3, (lv1075,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1078_1 = R.call_tir(cls.transpose6, (lv1066,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1079 = R.call_tir(cls.transpose6, (lv1076,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1080_1 = R.call_tir(cls.transpose6, (lv1077_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1139 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1078_1, lv1079, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1140 = R.call_tir(cls.fused_softmax1_cast4, (lv1139,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1089_1 = R.call_tir(cls.matmul10, (lv1140, lv1080_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1090_1 = R.call_tir(cls.transpose8, (lv1089_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1091_1 = R.call_tir(cls.reshape8, (lv1090_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1141: R.Tensor((4096, 512), dtype="uint32") = params[214]
            lv1142: R.Tensor((4096, 128), dtype="float16") = params[215]
            lv106_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1141, lv1142, lv1091_1, lv105_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv306_1: R.Tensor((4096,), dtype="float16") = params[221]
            lv1095_1 = R.call_tir(cls.rms_norm, (lv106_2, lv306_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1145: R.Tensor((22016, 512), dtype="uint32") = params[216]
            lv1146: R.Tensor((22016, 128), dtype="float16") = params[217]
            lv108 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1145, lv1146, lv1095_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1148 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv108,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1149: R.Tensor((4096, 1376), dtype="uint32") = params[218]
            lv1150: R.Tensor((4096, 344), dtype="float16") = params[219]
            lv107_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1149, lv1150, lv1148, lv106_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
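            # --- decoder layer: kv_cache slots 44/45 ---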
            lv313_1: R.Tensor((4096,), dtype="float16") = params[230]
            lv1106_1 = R.call_tir(cls.rms_norm, (lv107_1, lv313_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1153: R.Tensor((12288, 512), dtype="uint32") = params[222]
            lv1154: R.Tensor((12288, 128), dtype="float16") = params[223]
            lv109_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1153, lv1154, lv1106_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1109 = R.call_tir(cls.split1, (lv109_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1110: R.Tensor((1, n, 4096), dtype="float16") = lv1109[0]
            lv1111_1 = R.call_tir(cls.reshape7, (lv1110,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1112_1: R.Tensor((1, n, 4096), dtype="float16") = lv1109[1]
            lv1113 = R.call_tir(cls.reshape7, (lv1112_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1114_1: R.Tensor((1, n, 4096), dtype="float16") = lv1109[2]
            lv1115_1 = R.call_tir(cls.reshape7, (lv1114_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1116_1 = R.call_tir(cls.rotary_embedding, (lv1111_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1117 = R.call_tir(cls.rotary_embedding, (lv1113, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1118 = R.call_tir(cls.squeeze1, (lv1117,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1119_1 = R.call_tir(cls.squeeze1, (lv1115_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1120_1: R.Object = kv_cache[44]
            lv1121: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1120_1, lv1118, sinfo_args=(R.Object,))
            lv1122_1: R.Object = kv_cache[45]
            lv1123_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1122_1, lv1119_1, sinfo_args=(R.Object,))
            lv1124_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1121, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1125_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1123_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1126 = R.call_tir(cls.reshape3, (lv1124_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1127 = R.call_tir(cls.reshape3, (lv1125_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1128_1 = R.call_tir(cls.transpose6, (lv1116_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1129_1 = R.call_tir(cls.transpose6, (lv1126,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1130 = R.call_tir(cls.transpose6, (lv1127,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1156 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1128_1, lv1129_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1157 = R.call_tir(cls.fused_softmax1_cast4, (lv1156,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1139_1 = R.call_tir(cls.matmul10, (lv1157, lv1130), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1140_1 = R.call_tir(cls.transpose8, (lv1139_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1141_1 = R.call_tir(cls.reshape8, (lv1140_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1158: R.Tensor((4096, 512), dtype="uint32") = params[224]
            lv1159: R.Tensor((4096, 128), dtype="float16") = params[225]
            lv108_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1158, lv1159, lv1141_1, lv107_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv320_1: R.Tensor((4096,), dtype="float16") = params[231]
            lv1145_1 = R.call_tir(cls.rms_norm, (lv108_1, lv320_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1162: R.Tensor((22016, 512), dtype="uint32") = params[226]
            lv1163: R.Tensor((22016, 128), dtype="float16") = params[227]
            lv110_2 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1162, lv1163, lv1145_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1165 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv110_2,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1166: R.Tensor((4096, 1376), dtype="uint32") = params[228]
            lv1167: R.Tensor((4096, 344), dtype="float16") = params[229]
            lv109_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1166, lv1167, lv1165, lv108_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
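            # --- decoder layer: kv_cache slots 46/47 ---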
            lv327_1: R.Tensor((4096,), dtype="float16") = params[240]
            lv1156_1 = R.call_tir(cls.rms_norm, (lv109_2, lv327_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1170: R.Tensor((12288, 512), dtype="uint32") = params[232]
            lv1171: R.Tensor((12288, 128), dtype="float16") = params[233]
            lv111_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1170, lv1171, lv1156_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1159_1 = R.call_tir(cls.split1, (lv111_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1160: R.Tensor((1, n, 4096), dtype="float16") = lv1159_1[0]
            lv1161 = R.call_tir(cls.reshape7, (lv1160,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1162_1: R.Tensor((1, n, 4096), dtype="float16") = lv1159_1[1]
            lv1163_1 = R.call_tir(cls.reshape7, (lv1162_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1164: R.Tensor((1, n, 4096), dtype="float16") = lv1159_1[2]
            lv1165_1 = R.call_tir(cls.reshape7, (lv1164,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1166_1 = R.call_tir(cls.rotary_embedding, (lv1161, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1167_1 = R.call_tir(cls.rotary_embedding, (lv1163_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1168 = R.call_tir(cls.squeeze1, (lv1167_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1169 = R.call_tir(cls.squeeze1, (lv1165_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1170_1: R.Object = kv_cache[46]
            lv1171_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1170_1, lv1168, sinfo_args=(R.Object,))
            lv1172: R.Object = kv_cache[47]
            lv1173: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1172, lv1169, sinfo_args=(R.Object,))
            lv1174: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1171_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1175: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1173, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1176 = R.call_tir(cls.reshape3, (lv1174,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1177 = R.call_tir(cls.reshape3, (lv1175,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1178 = R.call_tir(cls.transpose6, (lv1166_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1179 = R.call_tir(cls.transpose6, (lv1176,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1180 = R.call_tir(cls.transpose6, (lv1177,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1173_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1178, lv1179, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1174_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1173_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1189 = R.call_tir(cls.matmul10, (lv1174_1, lv1180), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1190 = R.call_tir(cls.transpose8, (lv1189,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1191 = R.call_tir(cls.reshape8, (lv1190,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1175_1: R.Tensor((4096, 512), dtype="uint32") = params[234]
            lv1176_1: R.Tensor((4096, 128), dtype="float16") = params[235]
            lv110_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1175_1, lv1176_1, lv1191, lv109_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv334: R.Tensor((4096,), dtype="float16") = params[241]
            lv1195 = R.call_tir(cls.rms_norm, (lv110_3, lv334), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1179_1: R.Tensor((22016, 512), dtype="uint32") = params[236]
            lv1180_1: R.Tensor((22016, 128), dtype="float16") = params[237]
            lv112_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1179_1, lv1180_1, lv1195), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1182 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv112_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1183: R.Tensor((4096, 1376), dtype="uint32") = params[238]
            lv1184: R.Tensor((4096, 344), dtype="float16") = params[239]
            lv111_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1183, lv1184, lv1182, lv110_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
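            # --- decoder layer: kv_cache slots 48/49 ---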
            lv341_1: R.Tensor((4096,), dtype="float16") = params[250]
            lv1206 = R.call_tir(cls.rms_norm, (lv111_2, lv341_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1187: R.Tensor((12288, 512), dtype="uint32") = params[242]
            lv1188: R.Tensor((12288, 128), dtype="float16") = params[243]
            lv113_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1187, lv1188, lv1206), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1209 = R.call_tir(cls.split1, (lv113_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1210: R.Tensor((1, n, 4096), dtype="float16") = lv1209[0]
            lv1211 = R.call_tir(cls.reshape7, (lv1210,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1212: R.Tensor((1, n, 4096), dtype="float16") = lv1209[1]
            lv1213 = R.call_tir(cls.reshape7, (lv1212,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1214: R.Tensor((1, n, 4096), dtype="float16") = lv1209[2]
            lv1215 = R.call_tir(cls.reshape7, (lv1214,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1216 = R.call_tir(cls.rotary_embedding, (lv1211, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1217 = R.call_tir(cls.rotary_embedding, (lv1213, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1218 = R.call_tir(cls.squeeze1, (lv1217,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1219 = R.call_tir(cls.squeeze1, (lv1215,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1220: R.Object = kv_cache[48]
            lv1221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1220, lv1218, sinfo_args=(R.Object,))
            lv1222: R.Object = kv_cache[49]
            lv1223: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1222, lv1219, sinfo_args=(R.Object,))
            lv1224: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1221, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1225: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1223, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1226 = R.call_tir(cls.reshape3, (lv1224,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1227 = R.call_tir(cls.reshape3, (lv1225,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1228 = R.call_tir(cls.transpose6, (lv1216,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1229 = R.call_tir(cls.transpose6, (lv1226,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1230 = R.call_tir(cls.transpose6, (lv1227,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1190_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1228, lv1229, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1191_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1190_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1239 = R.call_tir(cls.matmul10, (lv1191_1, lv1230), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1240 = R.call_tir(cls.transpose8, (lv1239,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1241 = R.call_tir(cls.reshape8, (lv1240,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1192: R.Tensor((4096, 512), dtype="uint32") = params[244]
            lv1193: R.Tensor((4096, 128), dtype="float16") = params[245]
            lv112_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1192, lv1193, lv1241, lv111_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv348: R.Tensor((4096,), dtype="float16") = params[251]
            lv1245 = R.call_tir(cls.rms_norm, (lv112_2, lv348), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1196: R.Tensor((22016, 512), dtype="uint32") = params[246]
            lv1197: R.Tensor((22016, 128), dtype="float16") = params[247]
            lv114_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1196, lv1197, lv1245), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1199 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv114_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1200: R.Tensor((4096, 1376), dtype="uint32") = params[248]
            lv1201: R.Tensor((4096, 344), dtype="float16") = params[249]
            lv113_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1200, lv1201, lv1199, lv112_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
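            # --- decoder layer: kv_cache slots 50/51 ---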
            lv355: R.Tensor((4096,), dtype="float16") = params[260]
            lv1256 = R.call_tir(cls.rms_norm, (lv113_2, lv355), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1204: R.Tensor((12288, 512), dtype="uint32") = params[252]
            lv1205: R.Tensor((12288, 128), dtype="float16") = params[253]
            lv115_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1204, lv1205, lv1256), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1259 = R.call_tir(cls.split1, (lv115_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1260: R.Tensor((1, n, 4096), dtype="float16") = lv1259[0]
            lv1261 = R.call_tir(cls.reshape7, (lv1260,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1262: R.Tensor((1, n, 4096), dtype="float16") = lv1259[1]
            lv1263 = R.call_tir(cls.reshape7, (lv1262,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1264: R.Tensor((1, n, 4096), dtype="float16") = lv1259[2]
            lv1265 = R.call_tir(cls.reshape7, (lv1264,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1266 = R.call_tir(cls.rotary_embedding, (lv1261, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1267 = R.call_tir(cls.rotary_embedding, (lv1263, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1268 = R.call_tir(cls.squeeze1, (lv1267,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1269 = R.call_tir(cls.squeeze1, (lv1265,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1270: R.Object = kv_cache[50]
            lv1271: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1270, lv1268, sinfo_args=(R.Object,))
            lv1272: R.Object = kv_cache[51]
            lv1273: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1272, lv1269, sinfo_args=(R.Object,))
            lv1274: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1271, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1275: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1273, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1276 = R.call_tir(cls.reshape3, (lv1274,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1277 = R.call_tir(cls.reshape3, (lv1275,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1278 = R.call_tir(cls.transpose6, (lv1266,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1279 = R.call_tir(cls.transpose6, (lv1276,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1280 = R.call_tir(cls.transpose6, (lv1277,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1207 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1278, lv1279, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1208 = R.call_tir(cls.fused_softmax1_cast4, (lv1207,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1289 = R.call_tir(cls.matmul10, (lv1208, lv1280), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1290 = R.call_tir(cls.transpose8, (lv1289,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1291 = R.call_tir(cls.reshape8, (lv1290,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1209_1: R.Tensor((4096, 512), dtype="uint32") = params[254]
            lv1210_1: R.Tensor((4096, 128), dtype="float16") = params[255]
            lv114_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1209_1, lv1210_1, lv1291, lv113_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv362_1: R.Tensor((4096,), dtype="float16") = params[261]
            lv1295 = R.call_tir(cls.rms_norm, (lv114_2, lv362_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1213_1: R.Tensor((22016, 512), dtype="uint32") = params[256]
            lv1214_1: R.Tensor((22016, 128), dtype="float16") = params[257]
            lv116_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1213_1, lv1214_1, lv1295), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1216_1 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv116_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1217_1: R.Tensor((4096, 1376), dtype="uint32") = params[258]
            lv1218_1: R.Tensor((4096, 344), dtype="float16") = params[259]
            lv115_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1217_1, lv1218_1, lv1216_1, lv114_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
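            # Next decoder layer (KV-cache slots 52/53): input RMSNorm, fused
            # dequantize + QKV projection, rotary embedding, KV-cache append and
            # view, masked attention, output projection with residual add, then a
            # second RMSNorm feeding the SiLU-gated MLP and its residual add.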
            lv369_1: R.Tensor((4096,), dtype="float16") = params[270]
            lv1306 = R.call_tir(cls.rms_norm, (lv115_2, lv369_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1221_1: R.Tensor((12288, 512), dtype="uint32") = params[262]
            lv1222_1: R.Tensor((12288, 128), dtype="float16") = params[263]
            lv117_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1221_1, lv1222_1, lv1306), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1309 = R.call_tir(cls.split1, (lv117_2,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1310: R.Tensor((1, n, 4096), dtype="float16") = lv1309[0]
            lv1311 = R.call_tir(cls.reshape7, (lv1310,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1312: R.Tensor((1, n, 4096), dtype="float16") = lv1309[1]
            lv1313 = R.call_tir(cls.reshape7, (lv1312,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1314: R.Tensor((1, n, 4096), dtype="float16") = lv1309[2]
            lv1315 = R.call_tir(cls.reshape7, (lv1314,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1316 = R.call_tir(cls.rotary_embedding, (lv1311, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1317 = R.call_tir(cls.rotary_embedding, (lv1313, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1318 = R.call_tir(cls.squeeze1, (lv1317,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1319 = R.call_tir(cls.squeeze1, (lv1315,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1320: R.Object = kv_cache[52]
            lv1321: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1320, lv1318, sinfo_args=(R.Object,))
            lv1322: R.Object = kv_cache[53]
            lv1323: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1322, lv1319, sinfo_args=(R.Object,))
            lv1324: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1321, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1325: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1323, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1326 = R.call_tir(cls.reshape3, (lv1324,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1327 = R.call_tir(cls.reshape3, (lv1325,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1328 = R.call_tir(cls.transpose6, (lv1316,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1329 = R.call_tir(cls.transpose6, (lv1326,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1330 = R.call_tir(cls.transpose6, (lv1327,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1224_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1328, lv1329, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1225_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1224_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1339 = R.call_tir(cls.matmul10, (lv1225_1, lv1330), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1340 = R.call_tir(cls.transpose8, (lv1339,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1341 = R.call_tir(cls.reshape8, (lv1340,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1226_1: R.Tensor((4096, 512), dtype="uint32") = params[264]
            lv1227_1: R.Tensor((4096, 128), dtype="float16") = params[265]
            lv116_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1226_1, lv1227_1, lv1341, lv115_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv376_1: R.Tensor((4096,), dtype="float16") = params[271]
            lv1345 = R.call_tir(cls.rms_norm, (lv116_2, lv376_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1230_1: R.Tensor((22016, 512), dtype="uint32") = params[266]
            lv1231: R.Tensor((22016, 128), dtype="float16") = params[267]
            lv118_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1230_1, lv1231, lv1345), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1233 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv118_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1234: R.Tensor((4096, 1376), dtype="uint32") = params[268]
            lv1235: R.Tensor((4096, 344), dtype="float16") = params[269]
            lv117_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1234, lv1235, lv1233, lv116_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
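            # Decoder layer, KV-cache slots 54/55; same structure as the layer above.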
            lv383: R.Tensor((4096,), dtype="float16") = params[280]
            lv1356 = R.call_tir(cls.rms_norm, (lv117_3, lv383), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1238: R.Tensor((12288, 512), dtype="uint32") = params[272]
            lv1239_1: R.Tensor((12288, 128), dtype="float16") = params[273]
            lv119_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1238, lv1239_1, lv1356), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1359 = R.call_tir(cls.split1, (lv119_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1360: R.Tensor((1, n, 4096), dtype="float16") = lv1359[0]
            lv1361 = R.call_tir(cls.reshape7, (lv1360,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1362: R.Tensor((1, n, 4096), dtype="float16") = lv1359[1]
            lv1363 = R.call_tir(cls.reshape7, (lv1362,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1364: R.Tensor((1, n, 4096), dtype="float16") = lv1359[2]
            lv1365 = R.call_tir(cls.reshape7, (lv1364,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1366 = R.call_tir(cls.rotary_embedding, (lv1361, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1367 = R.call_tir(cls.rotary_embedding, (lv1363, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1368 = R.call_tir(cls.squeeze1, (lv1367,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1369 = R.call_tir(cls.squeeze1, (lv1365,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1370: R.Object = kv_cache[54]
            lv1371: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1370, lv1368, sinfo_args=(R.Object,))
            lv1372: R.Object = kv_cache[55]
            lv1373: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1372, lv1369, sinfo_args=(R.Object,))
            lv1374: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1371, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1375: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1373, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1376 = R.call_tir(cls.reshape3, (lv1374,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1377 = R.call_tir(cls.reshape3, (lv1375,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1378 = R.call_tir(cls.transpose6, (lv1366,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1379 = R.call_tir(cls.transpose6, (lv1376,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1380 = R.call_tir(cls.transpose6, (lv1377,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1241_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1378, lv1379, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1242 = R.call_tir(cls.fused_softmax1_cast4, (lv1241_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1389 = R.call_tir(cls.matmul10, (lv1242, lv1380), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1390 = R.call_tir(cls.transpose8, (lv1389,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1391 = R.call_tir(cls.reshape8, (lv1390,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1243: R.Tensor((4096, 512), dtype="uint32") = params[274]
            lv1244: R.Tensor((4096, 128), dtype="float16") = params[275]
            lv118_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1243, lv1244, lv1391, lv117_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv390_1: R.Tensor((4096,), dtype="float16") = params[281]
            lv1395 = R.call_tir(cls.rms_norm, (lv118_2, lv390_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1247: R.Tensor((22016, 512), dtype="uint32") = params[276]
            lv1248: R.Tensor((22016, 128), dtype="float16") = params[277]
            lv120_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1247, lv1248, lv1395), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1250 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv120_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1251: R.Tensor((4096, 1376), dtype="uint32") = params[278]
            lv1252: R.Tensor((4096, 344), dtype="float16") = params[279]
            lv119_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1251, lv1252, lv1250, lv118_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
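            # Decoder layer, KV-cache slots 56/57; same structure as the layers above.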
            lv397: R.Tensor((4096,), dtype="float16") = params[290]
            lv1406 = R.call_tir(cls.rms_norm, (lv119_2, lv397), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1255: R.Tensor((12288, 512), dtype="uint32") = params[282]
            lv1256_1: R.Tensor((12288, 128), dtype="float16") = params[283]
            lv121_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1255, lv1256_1, lv1406), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1409 = R.call_tir(cls.split1, (lv121_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1410: R.Tensor((1, n, 4096), dtype="float16") = lv1409[0]
            lv1411 = R.call_tir(cls.reshape7, (lv1410,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1412: R.Tensor((1, n, 4096), dtype="float16") = lv1409[1]
            lv1413 = R.call_tir(cls.reshape7, (lv1412,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1414: R.Tensor((1, n, 4096), dtype="float16") = lv1409[2]
            lv1415 = R.call_tir(cls.reshape7, (lv1414,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1416 = R.call_tir(cls.rotary_embedding, (lv1411, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1417 = R.call_tir(cls.rotary_embedding, (lv1413, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1418 = R.call_tir(cls.squeeze1, (lv1417,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1419 = R.call_tir(cls.squeeze1, (lv1415,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1420: R.Object = kv_cache[56]
            lv1421: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1420, lv1418, sinfo_args=(R.Object,))
            lv1422: R.Object = kv_cache[57]
            lv1423: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1422, lv1419, sinfo_args=(R.Object,))
            lv1424: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1421, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1425: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1423, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1426 = R.call_tir(cls.reshape3, (lv1424,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1427 = R.call_tir(cls.reshape3, (lv1425,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1428 = R.call_tir(cls.transpose6, (lv1416,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1429 = R.call_tir(cls.transpose6, (lv1426,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1430 = R.call_tir(cls.transpose6, (lv1427,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1258 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1428, lv1429, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1259_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1258,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1439 = R.call_tir(cls.matmul10, (lv1259_1, lv1430), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1440 = R.call_tir(cls.transpose8, (lv1439,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1441 = R.call_tir(cls.reshape8, (lv1440,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1260_1: R.Tensor((4096, 512), dtype="uint32") = params[284]
            lv1261_1: R.Tensor((4096, 128), dtype="float16") = params[285]
            lv120_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1260_1, lv1261_1, lv1441, lv119_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv404: R.Tensor((4096,), dtype="float16") = params[291]
            lv1445 = R.call_tir(cls.rms_norm, (lv120_2, lv404), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1264_1: R.Tensor((22016, 512), dtype="uint32") = params[286]
            lv1265_1: R.Tensor((22016, 128), dtype="float16") = params[287]
            lv122_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1264_1, lv1265_1, lv1445), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1267_1 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv122_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1268_1: R.Tensor((4096, 1376), dtype="uint32") = params[288]
            lv1269_1: R.Tensor((4096, 344), dtype="float16") = params[289]
            lv121_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1268_1, lv1269_1, lv1267_1, lv120_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
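            # Decoder layer, KV-cache slots 58/59; same structure as the layers above.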
            lv411_1: R.Tensor((4096,), dtype="float16") = params[300]
            lv1456 = R.call_tir(cls.rms_norm, (lv121_2, lv411_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1272_1: R.Tensor((12288, 512), dtype="uint32") = params[292]
            lv1273_1: R.Tensor((12288, 128), dtype="float16") = params[293]
            lv123_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1272_1, lv1273_1, lv1456), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1459 = R.call_tir(cls.split1, (lv123_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1460: R.Tensor((1, n, 4096), dtype="float16") = lv1459[0]
            lv1461 = R.call_tir(cls.reshape7, (lv1460,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1462: R.Tensor((1, n, 4096), dtype="float16") = lv1459[1]
            lv1463 = R.call_tir(cls.reshape7, (lv1462,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1464: R.Tensor((1, n, 4096), dtype="float16") = lv1459[2]
            lv1465 = R.call_tir(cls.reshape7, (lv1464,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1466 = R.call_tir(cls.rotary_embedding, (lv1461, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1467 = R.call_tir(cls.rotary_embedding, (lv1463, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1468 = R.call_tir(cls.squeeze1, (lv1467,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1469 = R.call_tir(cls.squeeze1, (lv1465,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1470: R.Object = kv_cache[58]
            lv1471: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1470, lv1468, sinfo_args=(R.Object,))
            lv1472: R.Object = kv_cache[59]
            lv1473: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1472, lv1469, sinfo_args=(R.Object,))
            lv1474: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1471, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1475: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1473, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1476 = R.call_tir(cls.reshape3, (lv1474,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1477 = R.call_tir(cls.reshape3, (lv1475,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1478 = R.call_tir(cls.transpose6, (lv1466,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1479 = R.call_tir(cls.transpose6, (lv1476,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1480 = R.call_tir(cls.transpose6, (lv1477,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1275_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1478, lv1479, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1276_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1275_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1489 = R.call_tir(cls.matmul10, (lv1276_1, lv1480), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1490 = R.call_tir(cls.transpose8, (lv1489,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1491 = R.call_tir(cls.reshape8, (lv1490,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1277_1: R.Tensor((4096, 512), dtype="uint32") = params[294]
            lv1278_1: R.Tensor((4096, 128), dtype="float16") = params[295]
            lv122_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1277_1, lv1278_1, lv1491, lv121_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv418_1: R.Tensor((4096,), dtype="float16") = params[301]
            lv1495 = R.call_tir(cls.rms_norm, (lv122_2, lv418_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1281: R.Tensor((22016, 512), dtype="uint32") = params[296]
            lv1282: R.Tensor((22016, 128), dtype="float16") = params[297]
            lv124_2 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1281, lv1282, lv1495), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1284 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv124_2,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1285: R.Tensor((4096, 1376), dtype="uint32") = params[298]
            lv1286: R.Tensor((4096, 344), dtype="float16") = params[299]
            lv123_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1285, lv1286, lv1284, lv122_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
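            # Decoder layer, KV-cache slots 60/61; same structure as the layers above.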
            lv425_1: R.Tensor((4096,), dtype="float16") = params[310]
            lv1506 = R.call_tir(cls.rms_norm, (lv123_2, lv425_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1289_1: R.Tensor((12288, 512), dtype="uint32") = params[302]
            lv1290_1: R.Tensor((12288, 128), dtype="float16") = params[303]
            lv125_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1289_1, lv1290_1, lv1506), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1509 = R.call_tir(cls.split1, (lv125_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1510: R.Tensor((1, n, 4096), dtype="float16") = lv1509[0]
            lv1511 = R.call_tir(cls.reshape7, (lv1510,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1512: R.Tensor((1, n, 4096), dtype="float16") = lv1509[1]
            lv1513 = R.call_tir(cls.reshape7, (lv1512,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1514: R.Tensor((1, n, 4096), dtype="float16") = lv1509[2]
            lv1515 = R.call_tir(cls.reshape7, (lv1514,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1516 = R.call_tir(cls.rotary_embedding, (lv1511, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1517 = R.call_tir(cls.rotary_embedding, (lv1513, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1518 = R.call_tir(cls.squeeze1, (lv1517,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1519 = R.call_tir(cls.squeeze1, (lv1515,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1520: R.Object = kv_cache[60]
            lv1521: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1520, lv1518, sinfo_args=(R.Object,))
            lv1522: R.Object = kv_cache[61]
            lv1523: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1522, lv1519, sinfo_args=(R.Object,))
            lv1524: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1521, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1525: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1523, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1526 = R.call_tir(cls.reshape3, (lv1524,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1527 = R.call_tir(cls.reshape3, (lv1525,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1528 = R.call_tir(cls.transpose6, (lv1516,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1529 = R.call_tir(cls.transpose6, (lv1526,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1530 = R.call_tir(cls.transpose6, (lv1527,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1292 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1528, lv1529, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1293 = R.call_tir(cls.fused_softmax1_cast4, (lv1292,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1539 = R.call_tir(cls.matmul10, (lv1293, lv1530), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1540 = R.call_tir(cls.transpose8, (lv1539,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1541 = R.call_tir(cls.reshape8, (lv1540,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1294: R.Tensor((4096, 512), dtype="uint32") = params[304]
            lv1295_1: R.Tensor((4096, 128), dtype="float16") = params[305]
            lv124_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1294, lv1295_1, lv1541, lv123_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv432: R.Tensor((4096,), dtype="float16") = params[311]
            lv1545 = R.call_tir(cls.rms_norm, (lv124_3, lv432), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1298: R.Tensor((22016, 512), dtype="uint32") = params[306]
            lv1299: R.Tensor((22016, 128), dtype="float16") = params[307]
            lv126_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1298, lv1299, lv1545), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1301 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv126_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1302: R.Tensor((4096, 1376), dtype="uint32") = params[308]
            lv1303: R.Tensor((4096, 344), dtype="float16") = params[309]
            lv125_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1302, lv1303, lv1301, lv124_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
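            # Final decoder layer in this function, KV-cache slots 62/63; same
            # structure as the layers above.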
            lv439_1: R.Tensor((4096,), dtype="float16") = params[320]
            lv1556 = R.call_tir(cls.rms_norm, (lv125_2, lv439_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1306_1: R.Tensor((12288, 512), dtype="uint32") = params[312]
            lv1307: R.Tensor((12288, 128), dtype="float16") = params[313]
            lv127_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1306_1, lv1307, lv1556), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
            lv1559 = R.call_tir(cls.split1, (lv127_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
            lv1560: R.Tensor((1, n, 4096), dtype="float16") = lv1559[0]
            lv1561 = R.call_tir(cls.reshape7, (lv1560,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1562: R.Tensor((1, n, 4096), dtype="float16") = lv1559[1]
            lv1563 = R.call_tir(cls.reshape7, (lv1562,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1564: R.Tensor((1, n, 4096), dtype="float16") = lv1559[2]
            lv1565 = R.call_tir(cls.reshape7, (lv1564,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1566 = R.call_tir(cls.rotary_embedding, (lv1561, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1567 = R.call_tir(cls.rotary_embedding, (lv1563, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
            lv1568 = R.call_tir(cls.squeeze1, (lv1567,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1569 = R.call_tir(cls.squeeze1, (lv1565,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
            lv1570: R.Object = kv_cache[62]
            lv1571: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1570, lv1568, sinfo_args=(R.Object,))
            lv1572: R.Object = kv_cache[63]
            lv1573: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1572, lv1569, sinfo_args=(R.Object,))
            lv1574: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1571, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1575: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1573, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
            lv1576 = R.call_tir(cls.reshape3, (lv1574,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1577 = R.call_tir(cls.reshape3, (lv1575,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
            lv1578 = R.call_tir(cls.transpose6, (lv1566,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1579 = R.call_tir(cls.transpose6, (lv1576,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1580 = R.call_tir(cls.transpose6, (lv1577,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
            lv1309_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1578, lv1579, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv1310_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1309_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1589 = R.call_tir(cls.matmul10, (lv1310_1, lv1580), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv1590 = R.call_tir(cls.transpose8, (lv1589,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
            lv1591 = R.call_tir(cls.reshape8, (lv1590,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1311_1: R.Tensor((4096, 512), dtype="uint32") = params[314]
            lv1312_1: R.Tensor((4096, 128), dtype="float16") = params[315]
            lv126_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1311_1, lv1312_1, lv1591, lv125_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv446: R.Tensor((4096,), dtype="float16") = params[321]
            lv1595 = R.call_tir(cls.rms_norm, (lv126_2, lv446), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1315_1: R.Tensor((22016, 512), dtype="uint32") = params[316]
            lv1316_1: R.Tensor((22016, 128), dtype="float16") = params[317]
            lv128_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1315_1, lv1316_1, lv1595), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1318_1 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv128_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1319_1: R.Tensor((4096, 1376), dtype="uint32") = params[318]
            lv1320_1: R.Tensor((4096, 344), dtype="float16") = params[319]
            lv127_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1319_1, lv1320_1, lv1318_1, lv126_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
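            # Epilogue: final RMSNorm, slice out the last token position, and
            # project it through the quantized output head to float32 logits over
            # the 32001-entry vocabulary.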
            lv453: R.Tensor((4096,), dtype="float16") = params[322]
            lv1606 = R.call_tir(cls.rms_norm, (lv127_2, lv453), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1607 = R.call_tir(cls.slice, (lv1606,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv1323_1: R.Tensor((32001, 512), dtype="uint32") = params[323]
            lv1324_1: R.Tensor((32001, 128), dtype="float16") = params[324]
            lv129_1 = R.call_tir(cls.fused_fused_decode1_fused_NT_matmul5_cast2, (lv1323_1, lv1324_1, lv1607), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
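            # Return the logits together with all 64 updated KV-cache handles
            # (one key cache and one value cache per decoder layer).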
            gv: R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)) = lv129_1, (lv21, lv23, lv71, lv73, lv121, lv123, lv171, lv173, lv221, lv223, lv271, lv273, lv321, lv323, lv371, lv373, lv421, lv423, lv471, lv473, lv521, lv523, lv571, lv573, lv621, lv623, lv671, lv673, lv721, lv723, lv771, lv773, lv821, lv823_1, lv871, lv873_1, lv921_1, lv923, lv971_1, lv973, lv1021_1, lv1023_1, lv1071_1, lv1073_1, lv1121, lv1123_1, lv1171_1, lv1173, lv1221, lv1223, lv1271, lv1273, lv1321, lv1323, lv1371, lv1373, lv1421, lv1423, lv1471, lv1473, lv1521, lv1523, lv1571, lv1573)
            R.output(gv)
        return gv

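    # Divide the logits by a scalar temperature and renormalize with a softmax
    # over the vocabulary axis; typically called by the runtime just before
    # sampling the next token.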
    @R.function
    def softmax_with_temperature(logits: R.Tensor((1, 1, 32001), dtype="float32"), temperature: R.Tensor((), dtype="float32")) -> R.Tensor((1, 1, 32001), dtype="float32"):
        R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
        cls = Module
        with R.dataflow():
            lv3285 = R.call_tir(cls.divide2, (logits, temperature), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
            lv3286 = R.call_tir(cls.softmax2, (lv3285,), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
            gv3: R.Tensor((1, 1, 32001), dtype="float32") = lv3286
            R.output(gv3)
        return gv3
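
# A minimal usage sketch, assuming TVM with Relax support and a CUDA device; the
# variable names below are illustrative and not part of the generated module:
#
#   import tvm
#   from tvm import relax
#
#   ex = relax.build(Module, target="cuda")    # compile to a Relax VM executable
#   vm = relax.VirtualMachine(ex, tvm.cuda())  # instantiate the VM on the GPU
#   probs = vm["softmax_with_temperature"](logits, temperature)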