# Untitled
# unknown
# python
# a year ago
# 1.5 kB
# 16
# Indexable
class DINOv2(nn.Module):
    """Frozen DINOv2 ViT backbone (loaded via torch.hub) used as a feature extractor.

    Args:
        model_args: Mapping with at least:
            ``"model_size"`` -- one of ``"small"``, ``"small_reg"``, ``"base"``,
            ``"base_reg"``; selects the facebookresearch/dinov2 hub checkpoint.
            ``"layer"`` -- which output ``get_features`` returns: ``"last"`` or
            ``"first"`` (``"avg"`` is reserved but not implemented).
        device: Target passed to ``Module.to`` for the backbone; also stored
            on ``self.device``.

    Raises:
        KeyError: If ``model_args`` lacks ``"model_size"`` or ``"layer"``.
        ValueError: If ``model_args["model_size"]`` is not a supported size.
    """

    # ``model_size`` config value -> facebookresearch/dinov2 torch.hub entry point.
    _HUB_MODELS = {
        "small": "dinov2_vits14",
        "small_reg": "dinov2_vits14_reg",
        "base": "dinov2_vitb14",
        "base_reg": "dinov2_vitb14_reg",
    }

    def __init__(self, model_args, device) -> None:
        super().__init__()
        self.model_size = model_args["model_size"]
        # Validate before touching the network so a bad config fails fast and
        # clearly, instead of later as a missing ``feat_extr`` attribute.
        try:
            hub_name = self._HUB_MODELS[self.model_size]
        except KeyError:
            raise ValueError(
                f"Unsupported model_size {self.model_size!r}; "
                f"expected one of {sorted(self._HUB_MODELS)}"
            ) from None
        self.feat_extr = torch.hub.load("facebookresearch/dinov2", hub_name)
        self.layer_num = model_args["layer"]
        self.feat_extr.eval()  # inference only: disable train-mode behaviour
        self.feat_extr.to(device)
        self.device = device
        self.patch_size = 14  # every checkpoint above is a "/14" patch model

    def get_features(self, imgs):
        """Run the backbone without gradients and return CLS and patch tokens.

        Args:
            imgs: Batch passed unchanged to the backbone; it is NOT moved to
                ``self.device`` here. Presumably a (B, 3, H, W) tensor with H
                and W multiples of ``self.patch_size`` -- TODO confirm with callers.

        Returns:
            dict with ``"cls"`` (class token) and ``"patch"`` (patch tokens),
            both layer-normalised outputs of the selected layer.

        Raises:
            NotImplementedError: If ``self.layer_num == "avg"``.
            ValueError: If ``self.layer_num`` is any other unrecognised value.
        """
        # TODO: accept a list of layers.
        with torch.no_grad():
            if self.layer_num == "last":
                out = self.feat_extr.forward_features(imgs)
                patch = out["x_norm_patchtokens"]
                cls = out["x_norm_clstoken"]
            elif self.layer_num == "first":
                # get_intermediate_layers treats an int ``n`` as "the last n
                # blocks", so the default n=1 would return the final block.
                # A sequence selects blocks by index: [0] is the first block.
                patch, cls = self.feat_extr.get_intermediate_layers(
                    imgs, n=[0], return_class_token=True
                )[0]
            elif self.layer_num == "avg":
                raise NotImplementedError("layer='avg' is not implemented yet")
            else:
                raise ValueError(
                    f"Unsupported layer {self.layer_num!r}; "
                    "expected 'last', 'first' or 'avg'"
                )
        return {"cls": cls, "patch": patch}
# Leave a Comment