Untitled
unknown
plain_text
10 months ago
1.8 kB
1
Indexable
# import some common libraries
import json
import os
import random
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# import some common detectron2 utilities
import detectron2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.logger import setup_logger

# NOTE(review): _PanopticPrediction is a private detectron2 helper (leading
# underscore) and may change between detectron2 releases.
from detectron2.utils.visualizer import GenericMask, Visualizer, _PanopticPrediction

setup_logger()

# Model-zoo config used for both the config file and its pretrained weights.
DEFAULT_CONFIG = "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"

# Stuff "category_id" that this extractor treats as sky. Presumably this is
# COCO "sky-other-merged" in detectron2's separated panoptic stuff metadata
# for the default config above -- verify before switching to another config.
SKY_CATEGORY_ID = 40


class SkyMaskExtractor:
    """Run a detectron2 panoptic-segmentation model and extract the sky region."""

    def __init__(self, config_file=DEFAULT_CONFIG, device=None):
        """Build the predictor.

        Args:
            config_file: detectron2 model-zoo config path; its pretrained
                checkpoint is loaded as the model weights.
            device: optional device string (e.g. "cpu", "cuda") written to
                cfg.MODEL.DEVICE; when None the config's default is kept.
        """
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file(config_file))
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file)
        if device is not None:
            self.cfg.MODEL.DEVICE = device
        self.predictor = DefaultPredictor(self.cfg)
        # Looked up once here instead of on every run() call.
        self.metadata = MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0])

    def run(self, im: np.ndarray):
        """Extract the sky mask from an image.

        Args:
            im: input image as a uint8 array, in the channel order expected
                by DefaultPredictor (cfg.INPUT.FORMAT, BGR by default --
                which matches what cv2.imread returns).

        Returns:
            A float array with 1.0 on sky pixels and 0.0 elsewhere (presumably
            the same height/width as ``im``), or None if no segment with
            category_id == SKY_CATEGORY_ID is present.
        """
        panoptic_seg, segments_info = self.predictor(im)["panoptic_seg"]
        pred = _PanopticPrediction(panoptic_seg.to("cpu"), segments_info, self.metadata)
        # semantic_masks() yields (mask, segment_info) pairs for stuff segments.
        for mask, sinfo in pred.semantic_masks():
            if sinfo["category_id"] == SKY_CATEGORY_ID:
                return mask.astype(float)
        return None


def main(argv=None):
    """Command-line entry: show the sky mask of the image given as argument.

    Returns a process exit code (0 success, 1 unreadable image, 2 bad usage).
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 1:
        print(f"usage: {sys.argv[0]} IMAGE_PATH", file=sys.stderr)
        return 2
    img_name = args[0]

    # cv2.imread signals failure by returning None rather than raising.
    im = cv2.imread(img_name)
    if im is None:
        print(f"error: could not read image {img_name!r}", file=sys.stderr)
        return 1

    extractor = SkyMaskExtractor()
    mask = extractor.run(im)
    if mask is None:
        # plt.imshow(None) would raise; report the empty result instead.
        print("no sky detected", file=sys.stderr)
        return 0

    plt.imshow(mask)
    plt.show()
    return 0


if __name__ == "__main__":
    sys.exit(main())
Editor is loading...
Leave a Comment