# Paste metadata: Untitled · unknown · plain_text · 10 months ago · 1.8 kB · 5 · Indexable
import numpy as np
import torch
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# Relative path to the SAM ViT-B checkpoint file.
sam_checkpoint = "sam_vit_b_01ec64.pth"
# Registry key; must match the architecture of the checkpoint above.
model_type = "vit_b"
# Prefer the GPU, but fall back to CPU instead of failing at import time on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the model once at import time; the helpers below all share `sam`.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
def show_anns(anns):
    """Render SAM annotations as a single RGBA overlay image.

    Args:
        anns: List of annotation dicts, e.g. as produced by
            ``SamAutomaticMaskGenerator.generate``. Each must provide a
            numeric ``'area'`` and a ``'segmentation'`` mask used to index the
            image (presumably a boolean (H, W) array, SAM's default
            ``binary_mask`` output). All masks must share the same height and
            width.

    Returns:
        A float ``(H, W, 4)`` array with values in ``[0, 1]``. Pixels not
        covered by any mask are white with alpha 0. Each mask is painted in a
        random RGB colour with alpha 0.3511. Returns ``None`` when ``anns`` is
        empty.
    """
    if not anns:
        return None
    # Paint the largest masks first so smaller masks that overlap them are
    # drawn on top and stay visible.
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0  # fully transparent background
    for ann in sorted_anns:
        color_mask = np.concatenate([np.random.random(3), [0.3511]])
        img[ann['segmentation']] = color_mask
    return img
def generate_mask(pil_image):
    """Run SAM automatic mask generation on an image at half resolution.

    The image is downscaled by 0.5 on both axes before segmentation, so the
    returned masks are in the downscaled coordinate frame.

    Side effects:
        Prints a short summary of the masks to stdout. When at least one mask
        is found, writes a colour overlay of all masks to ``result.jpg`` in
        the current working directory.

    Args:
        pil_image: Image convertible with ``np.array``; presumably a PIL
            image in RGB mode (uint8, HWC). TODO confirm with callers.

    Returns:
        The list of mask records returned by
        ``SamAutomaticMaskGenerator.generate`` (possibly empty).
    """
    image = np.array(pil_image)
    image = cv2.resize(image, None, fx=0.5, fy=0.5)
    # np.array(<PIL RGB image>) is already in RGB order, which SAM expects.
    # Do not apply cv2.COLOR_BGR2RGB here: that is only correct for BGR input
    # such as the output of cv2.imread().
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    print(len(masks))
    if not masks:
        # Nothing to inspect or render; show_anns would return None here.
        return masks
    print(masks[0].keys())
    print(masks[0])
    result_image = show_anns(masks)
    # Scale the [0, 1] float overlay to 8-bit before writing.
    # NOTE(review): the overlay has 4 channels (alpha included) but JPEG has no
    # alpha; confirm the written file looks as intended (PNG may be preferable).
    cv2.imwrite('result.jpg', (result_image * 255).astype(np.uint8))
    return masks
# Disabled standalone example kept for reference: this triple-quoted string is
# never executed. It loads the image with cv2.imread(), which returns BGR data,
# so the cv2.COLOR_BGR2RGB conversion inside it is correct for that input.
'''
image = cv2.imread('images/road.png')
image = cv2.resize(image, None, fx=0.5, fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
print(len(masks))
print(masks[0].keys())
print(masks[0])
result_image = show_anns(masks)
cv2.imwrite('result.jpg', (result_image * 255).astype(np.uint8))
'''
def get_image_embedding(pil_image):
    """Return SAM's image-encoder embedding for ``pil_image``.

    Each call builds a fresh ``SamPredictor`` around the shared module-level
    ``sam`` model and calls ``set_image`` with its default arguments.

    Args:
        pil_image: Image convertible with ``np.array``; presumably RGB uint8
            (HWC), which ``SamPredictor.set_image`` assumes by default.
            TODO confirm with callers.

    Returns:
        The value of ``SamPredictor.get_image_embedding()``. Per the
        segment_anything docs, this is a torch tensor of shape 1xCxHxW.
    """
    embedder = SamPredictor(sam)
    embedder.set_image(np.array(pil_image))
    return embedder.get_image_embedding()
# Paste-site page residue: "Editor is loading..." / "Leave a Comment"