Untitled
unknown
plain_text
a year ago
1.8 kB
4
Indexable
# import some common libraries
import numpy as np
import os, json, cv2, random
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, GenericMask, _PanopticPrediction
from detectron2.data import MetadataCatalog, DatasetCatalog
from PIL import Image
import matplotlib.pyplot as plt
class SkyMaskExtractor:
    """Extract a sky mask from an image with detectron2's COCO Panoptic FPN (R101, 3x).

    The model architecture and pretrained weights both come from the
    detectron2 model zoo entry named by ``CONFIG_FILE``.
    """

    # Model-zoo config used for both the architecture and the checkpoint URL.
    CONFIG_FILE = "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"

    # Category id (as reported in the panoptic ``segments_info``) treated as sky.
    # NOTE(review): 40 is assumed to be the index of "sky" in the COCO
    # panoptic-separated ``stuff_classes`` (index 0 being the merged "things"
    # class) — confirm against ``self.metadata.stuff_classes`` if the config changes.
    SKY_CATEGORY_ID = 40

    def __init__(self, device=None):
        """Build the panoptic segmentation predictor.

        Args:
            device: optional device string (e.g. "cpu", "cuda:0") written to
                ``cfg.MODEL.DEVICE`` before the predictor is built. ``None``
                keeps the config's default device.
        """
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file(self.CONFIG_FILE))
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.CONFIG_FILE)
        if device is not None:
            self.cfg.MODEL.DEVICE = device
        self.predictor = DefaultPredictor(self.cfg)
        # The metadata depends only on the config, so resolve it once here
        # rather than on every run() call.
        self.metadata = MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0])

    def run(self, im: np.ndarray):
        """Extract the sky mask from an image.

        Args:
            im: uint8 image passed straight to the detectron2 predictor
                (presumably BGR, as produced by ``cv2.imread`` — confirm
                against ``cfg.INPUT.FORMAT``).

        Returns:
            The sky mask converted to a float array if a sky segment is
            present, otherwise ``None``.
        """
        panoptic_seg, segments_info = self.predictor(im)["panoptic_seg"]
        pred = _PanopticPrediction(panoptic_seg.to("cpu"), segments_info, self.metadata)
        return self._select_mask(pred.semantic_masks(), self.SKY_CATEGORY_ID)

    @staticmethod
    def _select_mask(semantic_masks, category_id):
        """Return the first mask in ``semantic_masks`` whose category matches.

        Args:
            semantic_masks: iterable of ``(mask, segment_info)`` pairs, where
                ``segment_info`` is a dict with a ``"category_id"`` key.
            category_id: category id to look for.

        Returns:
            ``mask.astype(float)`` for the first matching pair, or ``None``
            when no pair matches.
        """
        for mask, sinfo in semantic_masks:
            if sinfo["category_id"] == category_id:
                return mask.astype(float)
        return None
if __name__ == "__main__":
    import sys

    # Image path is taken from the command line; there is no other source for it.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    img_name = sys.argv[1]
    im = cv2.imread(img_name)
    # cv2.imread signals an unreadable/missing file by returning None, not raising.
    if im is None:
        sys.exit(f"could not read image: {img_name}")
    extractor = SkyMaskExtractor()
    mask = extractor.run(im)
    # run() returns None when no sky segment is found; imshow(None) would crash.
    if mask is None:
        print("no sky detected")
    else:
        plt.imshow(mask)
        plt.show()
Editor is loading...
Leave a Comment