Untitled

 avatar
unknown
python
5 months ago
1.5 kB
5
Indexable
class DINOv2(nn.Module):
    """Frozen DINOv2 ViT backbone, loaded via ``torch.hub``, used as a feature extractor.

    ``model_args`` must provide:
      * ``"model_size"`` -- one of ``"small"``, ``"small_reg"``, ``"base"``,
        ``"base_reg"`` (see ``_HUB_MODELS``).
      * ``"layer"`` -- which transformer block ``get_features`` reads tokens
        from: ``"last"``, ``"first"``, or a non-negative int block index.
        ``"avg"`` is reserved but not implemented yet.
    """

    # model_size option -> torch.hub entry point in facebookresearch/dinov2.
    _HUB_MODELS = {
        "small": "dinov2_vits14",
        "small_reg": "dinov2_vits14_reg",
        "base": "dinov2_vitb14",
        "base_reg": "dinov2_vitb14_reg",
    }

    def __init__(self, model_args, device) -> None:
        """Load the requested backbone, put it in eval mode and move it to ``device``.

        Raises:
            ValueError: if ``model_args["model_size"]`` is not a known size.
            KeyError: if ``model_args`` lacks ``"model_size"`` or ``"layer"``.
        """
        super().__init__()
        self.model_size = model_args["model_size"]
        hub_name = self._HUB_MODELS.get(self.model_size)
        if hub_name is None:
            # Validate before torch.hub.load so a typo fails fast with a clear
            # message instead of leaving ``feat_extr`` unset.
            raise ValueError(
                f"Unknown DINOv2 model_size {self.model_size!r}; "
                f"expected one of {sorted(self._HUB_MODELS)}"
            )
        self.feat_extr = torch.hub.load("facebookresearch/dinov2", hub_name)
        self.layer_num = model_args["layer"]
        self.feat_extr.eval()
        self.feat_extr.to(device)
        self.device = device
        # Matches the "14" in every hub model name above (ViT-S/14, ViT-B/14).
        self.patch_size = 14

    @staticmethod
    def _block_index(layer):
        """Map a ``layer`` option other than ``"last"`` to a transformer block index."""
        if layer == "first":
            return 0
        if isinstance(layer, int) and not isinstance(layer, bool):
            if layer < 0:
                # DINOv2 matches indices with ``i in n``, so negatives never match.
                raise ValueError(f"layer index must be non-negative, got {layer}")
            return layer
        if layer == "avg":
            raise NotImplementedError("layer='avg' is not implemented yet")
        raise ValueError(
            f"Unsupported layer {layer!r}; expected 'last', 'first' or an int block index"
        )

    def get_features(self, imgs):
        """Extract patch and CLS tokens from the configured layer, without gradients.

        Args:
            imgs: image batch passed straight to the backbone (presumably a
                normalized ``(B, 3, H, W)`` tensor with H and W multiples of
                ``patch_size`` -- TODO confirm with callers).

        Returns:
            dict with keys ``"cls"`` and ``"patch"`` holding the backbone's
            class token and patch tokens for the selected layer.

        Raises:
            NotImplementedError: if ``layer`` is ``"avg"``.
            ValueError: if ``layer`` is not a supported option.
        """
        layer = self.layer_num
        with torch.no_grad():
            if layer == "last":
                out = self.feat_extr.forward_features(imgs)
                patch = out["x_norm_patchtokens"]
                cls = out["x_norm_clstoken"]
            else:
                block_idx = self._block_index(layer)
                # With an int ``n``, get_intermediate_layers returns the *last*
                # n blocks (so the default n=1 is the final block); a list
                # selects explicit block indices instead.
                patch, cls = self.feat_extr.get_intermediate_layers(
                    imgs, n=[block_idx], return_class_token=True
                )[0]
        return {"cls": cls, "patch": patch}
Editor is loading...
Leave a Comment