import os

import torch
from transformers.models.llama.modeling_llama import LlamaMLP

# Assumes `model`, `dump_dest`, `desired_rank_pref`, `layer_idx`, `get_uv`,
# `infer_device`, and the project-local `utils` module are defined in the
# surrounding code.
for name, module in model.named_modules():
    if isinstance(module, LlamaMLP):
        suffix_list = ["gate_proj", "up_proj", "down_proj"]
        for suffix in suffix_list:
            # Flag buffer marking whether the low-rank path is active for this projection.
            module.register_buffer(f'{suffix}_use', torch.tensor([False]))
            # Load the dumped SVD factors and the inverse scaling matrix for this projection.
            u = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.u"), map_location=torch.device(infer_device()))
            s = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.s"), map_location=torch.device(infer_device()))
            v = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.v"), map_location=torch.device(infer_device()))
            scaling_matrix_inv = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.scaling_matrix_inv"), map_location=torch.device(infer_device()))
            # Read the desired rank for the current layer.
            k = desired_rank_pref[f'{layer_idx}'][suffix][0]
            u, v = get_uv(scaling_matrix_inv, u, s, v, k)
            print('get u v: ', name, suffix, k, u.shape, v.shape, u.device, v.device)
            # Register the rank-k factors as bfloat16 buffers.
            if suffix == "gate_proj":
                module.register_buffer('gate_weight_U_top', v.t().to(torch.bfloat16))
                module.register_buffer('gate_weight_SVh_top', u.t().to(torch.bfloat16))
            elif suffix == "up_proj":
                module.register_buffer('up_weight_U_top', v.t().to(torch.bfloat16))
                module.register_buffer('up_weight_SVh_top', u.t().to(torch.bfloat16))
            else:
                module.register_buffer('down_weight_U_top', v.t().to(torch.bfloat16))
                module.register_buffer('down_weight_SVh_top', u.t().to(torch.bfloat16))
            # Free the full-rank factors before the next projection.
            del u
            del s
            del v
            del scaling_matrix_inv
            utils.clear_torch_cache()
        # The low-rank buffers replace the original dense projections.
        del module.gate_proj
        del module.up_proj
        del module.down_proj
        utils.clear_torch_cache()
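
The get_uv helper and the dumping script are not part of this paste. As a rough sketch only (not the original implementation), assuming the dumped factors come from an SVD-LLM/ASVD-style whitened decomposition, i.e. W @ S ≈ u @ diag(s) @ v.T with S the activation scaling matrix, a rank-k truncation that folds the inverse scaling back in could look like the following (only the call signature of get_uv appears above; the body here is an assumption):

import torch

def get_uv(scaling_matrix_inv, u, s, v, k):
    # Hypothetical sketch: assumes W @ S ~= u @ diag(s) @ v.T, so that
    # W ~= (u_k @ diag(s_k)) @ (v_k @ S_inv) after keeping the top-k components.
    u_k = u[:, :k] @ torch.diag(s[:k])        # absorb singular values into the left factor, shape [out, k]
    v_k = v[:, :k].t() @ scaling_matrix_inv   # undo the activation whitening, shape [k, in]
    return u_k, v_k

Under that assumption, the registered buffers compose as x @ {suffix}_weight_U_top @ {suffix}_weight_SVh_top = x @ v_k.t() @ u_k.t() ≈ x @ W.T, which matches nn.Linear's forward semantics and is why the original gate_proj/up_proj/down_proj modules can be deleted afterwards.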