Untitled
unknown
python
5 months ago
1.5 kB
5
Indexable
class DINOv2(nn.Module):
    """Frozen DINOv2 backbone used as a feature extractor.

    The backbone is fetched through ``torch.hub`` from the
    ``facebookresearch/dinov2`` repository, switched to eval mode and moved
    to ``device``. Features are always computed under ``torch.no_grad()``.

    Args:
        model_args: Mapping with two required keys:
            ``"model_size"`` -- one of ``"small"``, ``"small_reg"``,
            ``"base"``, ``"base_reg"``;
            ``"layer"`` -- which layer to read features from
            (``"last"`` or ``"first"``; ``"avg"`` is reserved but not
            implemented yet).
        device: Device the backbone is moved to (anything accepted by
            ``nn.Module.to``).

    Raises:
        KeyError: If ``model_args`` lacks ``"model_size"`` or ``"layer"``.
        ValueError: If ``model_size`` is not a supported variant.
    """

    # Maps the ``model_size`` config value to its torch.hub entry point.
    _HUB_ENTRYPOINTS = {
        "small": "dinov2_vits14",
        "small_reg": "dinov2_vits14_reg",
        "base": "dinov2_vitb14",
        "base_reg": "dinov2_vitb14_reg",
    }

    def __init__(self, model_args, device) -> None:
        super().__init__()
        self.model_size = model_args["model_size"]

        # Fail fast with a clear message instead of an AttributeError on
        # ``self.feat_extr`` further down (and before any download starts).
        entrypoint = self._HUB_ENTRYPOINTS.get(self.model_size)
        if entrypoint is None:
            supported = ", ".join(sorted(self._HUB_ENTRYPOINTS))
            raise ValueError(
                f"Unsupported model_size {self.model_size!r}; "
                f"expected one of: {supported}"
            )
        self.feat_extr = torch.hub.load("facebookresearch/dinov2", entrypoint)

        self.layer_num = model_args["layer"]
        self.feat_extr.eval()
        self.feat_extr.to(device)
        self.device = device
        # All supported hub variants are the "*14" models.
        self.patch_size = 14

    def get_features(self, imgs):
        """Extract class-token and patch-token features from ``imgs``.

        Args:
            imgs: Batch of images forwarded to the backbone as-is
                (assumed to already live on ``self.device`` -- the method
                does not move them).

        Returns:
            dict with keys ``"cls"`` (class-token features) and ``"patch"``
            (patch-token features) for the layer selected by
            ``self.layer_num``.

        Raises:
            NotImplementedError: If ``self.layer_num == "avg"``.
            ValueError: If ``self.layer_num`` is not a recognised option.
        """
        # TODO: make layer_num a list
        with torch.no_grad():
            if self.layer_num == "last":
                feats = self.feat_extr.forward_features(imgs)
                patch = feats["x_norm_patchtokens"]
                cls = feats["x_norm_clstoken"]
            elif self.layer_num == "first":
                # An integer ``n`` (default 1) selects the *last* n blocks in
                # the DINOv2 API; a list selects blocks by index, so ``[0]``
                # is required to actually read the first block.
                patch, cls = self.feat_extr.get_intermediate_layers(
                    imgs, n=[0], return_class_token=True
                )[0]
            elif self.layer_num == "avg":
                # Previously fell through and crashed with UnboundLocalError.
                raise NotImplementedError(
                    'layer "avg" is not implemented yet'
                )
            else:
                raise ValueError(
                    f"Unsupported layer {self.layer_num!r}; "
                    'expected "last", "first" or "avg"'
                )

        return {"cls": cls, "patch": patch}
Editor is loading...
Leave a Comment