Untitled

 avatar
unknown
plain_text
10 months ago
1.8 kB
1
Indexable

# import some common libraries
import numpy as np
import os, json, cv2, random
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, GenericMask, _PanopticPrediction
from detectron2.data import MetadataCatalog, DatasetCatalog
from PIL import Image
import matplotlib.pyplot as plt


class SkyMaskExtractor:
    """Extract a sky mask from an image with a detectron2 Panoptic FPN model.

    The model architecture and checkpoint both come from the detectron2 model
    zoo entry named by CONFIG_FILE (COCO panoptic segmentation, R101, 3x).
    """

    # Model-zoo entry used for both the config and the checkpoint URL.
    CONFIG_FILE = "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"
    # Stuff category id matched by run(). This is the id the original code
    # hard-coded for sky; it is assumed to be the contiguous stuff id of "sky"
    # in the metadata of cfg.DATASETS.TRAIN[0] — re-verify if the config changes.
    SKY_CATEGORY_ID = 40

    def __init__(self, sky_category_id=SKY_CATEGORY_ID, device=None):
        """
            Build the panoptic predictor.
            Input:
                sky_category_id, int: stuff category id whose mask run()
                    returns. Defaults to SKY_CATEGORY_ID (40).
                device, str or None: if given, overrides cfg.MODEL.DEVICE
                    (e.g. "cpu" on machines without CUDA). None keeps the
                    config's default device.
        """
        self.sky_category_id = sky_category_id
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file(self.CONFIG_FILE))
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.CONFIG_FILE)
        if device is not None:
            # Must be set before the predictor is constructed.
            self.cfg.MODEL.DEVICE = device
        self.predictor = DefaultPredictor(self.cfg)
        # The metadata does not change between calls; resolve it once here
        # instead of on every run().
        self.metadata = MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0])

    def run(self, im: np.ndarray):
        """
            Extract sky mask on image.
            Input:
                im, np.ndarray(np.uint8): Input image in the predictor's input
                    format (cfg.INPUT.FORMAT; BGR by default, which matches
                    cv2.imread output).
            Returns:
                mask or None: float mask (1.0 where the sky category was
                    predicted, 0.0 elsewhere) if sky is present. Otherwise None.
        """
        panoptic_seg, segments_info = self.predictor(im)["panoptic_seg"]
        # NOTE: _PanopticPrediction is a private detectron2 helper (used by its
        # Visualizer); its interface may change between detectron2 releases.
        pred = _PanopticPrediction(panoptic_seg.to("cpu"), segments_info, self.metadata)
        for mask, sinfo in pred.semantic_masks():
            if sinfo["category_id"] == self.sky_category_id:
                return mask.astype(float)
        return None

if __name__ == "__main__":
    # Usage: python <this_file> IMAGE_PATH
    # Shows the extracted sky mask for IMAGE_PATH with matplotlib.
    import sys

    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    img_name = sys.argv[1]
    # cv2.imread signals an unreadable/missing file by returning None, not by raising.
    im = cv2.imread(img_name)
    if im is None:
        sys.exit(f"could not read image: {img_name}")
    extractor = SkyMaskExtractor()
    mask = extractor.run(im)
    if mask is None:
        # run() returns None when no sky segment was predicted; imshow(None) would crash.
        print("no sky detected")
    else:
        plt.imshow(mask)
        plt.show()
Editor is loading...
Leave a Comment