# Load precomputed SVD factors and register them as low-rank buffers on a
# Llama model's MLP and attention projections.
# Assumed to exist in the surrounding context (not defined in this snippet):
# model, utils.clear_torch_cache, infer_device, dump_dest, dump_dest_attn,
# desired_rank_pref.
import os
import re

import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP

def get_uv(scaling_matrix_inv, u, s, v, k):
    """Fold the inverse scaling matrix into V, truncate the SVD to rank k,
    and split sqrt(S) between the two factors.

    Returns u of shape (k, m) and v of shape (n, k); u.T @ v.T is the rank-k
    approximation of the de-scaled weight.
    """
    if v.device != scaling_matrix_inv.device:
        scaling_matrix_inv = scaling_matrix_inv.to(v.device)
    v = v @ scaling_matrix_inv  # fold the inverse scaling back into V
    svd_u = u[:, :k]            # top-k left singular vectors, (m, k)
    svd_s = s[:k]               # top-k singular values
    svd_v = v[:k, :]            # top-k right singular vectors, (k, n)
    sqrt_s = torch.diag(torch.sqrt(svd_s))
    # Defensive device alignment in case the factors were dumped on
    # different devices.
    if svd_u.device != sqrt_s.device:
        print('svd u s device: ', svd_u.device, sqrt_s.device)
        svd_u = svd_u.to(sqrt_s.device)
    if sqrt_s.device != svd_v.device:
        print('svd s v device: ', sqrt_s.device, svd_v.device)
        svd_v = svd_v.to(sqrt_s.device)
    u = (svd_u @ sqrt_s).T      # (k, m)
    v = (sqrt_s @ svd_v).T      # (n, k)
    return u, v
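
# Hedged self-check (illustrative, not part of the original pipeline): with an
# identity scaling matrix and k equal to the full rank, the two factors from
# get_uv should reproduce W via u_lr.T @ v_lr.T. All names below are made up
# for the check.
def _check_get_uv():
    m, n = 8, 6
    k = min(m, n)  # full rank, so truncation loses nothing
    W = torch.randn(m, n)
    scaling_inv = torch.eye(n)  # identity stand-in for scaling_matrix_inv
    u, s, vh = torch.linalg.svd(W, full_matrices=False)
    u_lr, v_lr = get_uv(scaling_inv, u, s, vh, k)
    # get_uv returns u_lr with shape (k, m) and v_lr with shape (n, k)
    assert torch.allclose(u_lr.t() @ v_lr.t(), W, atol=1e-4)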

for name, module in model.named_modules():
    if isinstance(module, LlamaMLP):
        # Drop the dense projections; the low-rank buffers replace them.
        del module.gate_proj
        del module.up_proj
        del module.down_proj
        utils.clear_torch_cache()
        # Layer index parsed from the module name, e.g. "model.layers.3.mlp".
        layer_idx = re.search(r'layers\.(\d+)', name).group(1)
        suffix_list = ["gate_proj", "up_proj", "down_proj"]
        for suffix in suffix_list:
            module.register_buffer(f'{suffix}_use', torch.Tensor([False]))
            # Load the cached SVD factors and the inverse scaling matrix.
            u = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.u"), map_location=torch.device(infer_device()))
            s = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.s"), map_location=torch.device(infer_device()))
            v = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.v"), map_location=torch.device(infer_device()))
            scaling_matrix_inv = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.scaling_matrix_inv"), map_location=torch.device(infer_device()))
            # Read the desired rank for the current layer.
            k = desired_rank_pref[f'{layer_idx}'][suffix][0]
            u, v = get_uv(scaling_matrix_inv, u, s, v, k)
            print('get u v: ', name, suffix, k, u.shape, v.shape, u.device, v.device)
            if suffix == "gate_proj":
                module.register_buffer('gate_weight_U_top', v.t().to(torch.bfloat16))
                module.register_buffer('gate_weight_SVh_top', u.t().to(torch.bfloat16))
            elif suffix == "up_proj":
                module.register_buffer('up_weight_U_top', v.t().to(torch.bfloat16))
                module.register_buffer('up_weight_SVh_top', u.t().to(torch.bfloat16))
            else:
                module.register_buffer('down_weight_U_top', v.t().to(torch.bfloat16))
                module.register_buffer('down_weight_SVh_top', u.t().to(torch.bfloat16))
            del u, s, v, scaling_matrix_inv
            utils.clear_torch_cache()
    elif isinstance(module, LlamaAttention):
        # Same scheme for the attention q/k projections.
        layer_idx = re.search(r'layers\.(\d+)', name).group(1)
        suffix_list = ["q_proj", "k_proj"]
        for suffix in suffix_list:
            # Load the cached SVD factors and the inverse scaling matrix.
            u = torch.load(os.path.join(dump_dest_attn, f"{name}.{suffix}.u"), map_location=torch.device(infer_device()))
            s = torch.load(os.path.join(dump_dest_attn, f"{name}.{suffix}.s"), map_location=torch.device(infer_device()))
            v = torch.load(os.path.join(dump_dest_attn, f"{name}.{suffix}.v"), map_location=torch.device(infer_device()))
            scaling_matrix_inv = torch.load(os.path.join(dump_dest_attn, f"{name}.{suffix}.scaling_matrix_inv"), map_location=torch.device(infer_device()))
            # Read the desired rank for the current layer.
            k = desired_rank_pref[f'{layer_idx}'][suffix][0]
            u, v = get_uv(scaling_matrix_inv, u, s, v, k)
            print('attn get u v: ', name, suffix, k, u.shape, v.shape, u.device, v.device)
            if suffix == "q_proj":
                module.register_buffer('q_weight_U_top', v.t().to(torch.bfloat16))
                module.register_buffer('q_weight_SVh_top', u.t().to(torch.bfloat16))
            else:
                module.register_buffer('k_weight_U_top', v.t().to(torch.bfloat16))
                module.register_buffer('k_weight_SVh_top', u.t().to(torch.bfloat16))
            del u, s, v, scaling_matrix_inv
            utils.clear_torch_cache()
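
# Hedged sketch (an assumption, not the original forward pass): one way the
# registered buffers could stand in for a dense projection. With
# U_top of shape (k, in_features) and SVh_top of shape (out_features, k),
# y = x @ W.T ~= F.linear(F.linear(x, U_top), SVh_top).
import torch.nn.functional as F

def lowrank_proj(x, U_top, SVh_top):
    # Two thin matmuls, O(k * (in + out)) per token, instead of one dense
    # in_features * out_features matmul.
    return F.linear(F.linear(x, U_top), SVh_top)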