# Untitled
import numpy as np
import torch
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# SAM ViT-B checkpoint, resolved relative to the current working directory.
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"
# Fall back to CPU so the module can still be imported on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The model is loaded once at import time and shared by every helper below.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)


def _to_rgb_array(pil_image):
    """Return *pil_image* as an RGB NumPy array (the channel order SAM expects).

    PIL images in any other mode (L, P, RGBA, CMYK, ...) are converted to RGB
    first, so grayscale/palette images gain three channels and alpha is
    dropped. Anything without a PIL ``convert`` method (e.g. an existing
    array) is passed through ``np.array`` unchanged.
    """
    if hasattr(pil_image, "convert") and getattr(pil_image, "mode", "RGB") != "RGB":
        pil_image = pil_image.convert("RGB")
    return np.array(pil_image)


def show_anns(anns):
    """Render SAM mask annotations as a single RGBA overlay image.

    Args:
        anns: list of mask records as produced by
            ``SamAutomaticMaskGenerator.generate``. Each record needs a
            ``'segmentation'`` mask of shape (H, W) (presumably boolean, used
            as an index mask) and a sortable ``'area'`` value.

    Returns:
        A float array of shape (H, W, 4) with values in [0, 1]. Pixels not
        covered by any mask are white with alpha 0. Each mask is painted a
        random colour with alpha 0.3511. Returns ``None`` when *anns* is
        empty.
    """
    if not anns:
        return None
    # Paint the largest masks first so smaller masks drawn later stay visible.
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0  # fully transparent background
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.3511]])
        img[m] = color_mask
    return img


def generate_mask(pil_image, scale=0.5, output_path='result.jpg'):
    """Run SAM automatic mask generation on *pil_image*.

    The image is converted to RGB and resized by *scale* on both axes
    (halving each side by default). It is then segmented with
    ``SamAutomaticMaskGenerator``, and a colour overlay of the masks is
    written to *output_path*.

    Args:
        pil_image: a PIL image (any mode) or an (H, W, 3) RGB array.
        scale: resize factor applied to width and height before segmentation.
        output_path: file the overlay visualisation is written to.

    Returns:
        The list of mask dicts from ``SamAutomaticMaskGenerator.generate``.
        Mask coordinates refer to the *resized* image, not the input.
    """
    # PIL data is already RGB, which is what SAM consumes. A BGR->RGB swap is
    # only needed for cv2.imread output and would corrupt colours here.
    image = _to_rgb_array(pil_image)
    image = cv2.resize(image, None, fx=scale, fy=scale)
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    print(len(masks))
    if not masks:
        # Nothing to inspect or draw (show_anns would return None).
        return masks
    print(masks[0].keys())
    print(masks[0])
    result_image = show_anns(masks)
    # JPEG has no alpha channel; the overlay's alpha plane is presumably not
    # preserved when output_path ends in .jpg.
    cv2.imwrite(output_path, (result_image * 255).astype(np.uint8))
    return masks


def get_image_embedding(pil_image):
    """Return SAM's image-encoder embedding for *pil_image*.

    Args:
        pil_image: a PIL image (any mode) or an (H, W, 3) RGB array.

    Returns:
        The value of ``SamPredictor.get_image_embedding()``. It is presumably
        a torch tensor; TODO confirm its shape against the installed
        segment_anything version.
    """
    predictor = SamPredictor(sam)
    image = _to_rgb_array(pil_image)
    predictor.set_image(image)  # SamPredictor.set_image defaults to RGB input
    return predictor.get_image_embedding()
# Leave a Comment