# Untitled
#
# avatar
# unknown
# plain_text
# 2 months ago
# 1.8 kB
# 3
# Indexable
import numpy as np 
import torch
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Module-level SAM model shared by every function in this file.
# `model_type` is the registry key used to build the model that the
# checkpoint file is loaded into.
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"
# Fall back to the CPU so that importing this module does not fail on
# machines without a usable CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

def show_anns(anns):
    """Render mask annotations as one semi-transparent RGBA overlay.

    Args:
        anns: Sequence of annotation dicts. Each dict must provide a
            ``'segmentation'`` mask array (used to index the canvas, so
            presumably a boolean H x W array, all of the same shape) and
            a sortable ``'area'`` value.

    Returns:
        A float array of shape (H, W, 4) with values in [0, 1]: pixels
        covered by a mask get a random RGB colour with alpha 0.3511,
        uncovered pixels stay white with alpha 0. Returns ``None`` when
        ``anns`` is empty, because no canvas size can be derived.
    """
    if not anns:
        return None

    # Paint the largest masks first so that smaller masks, drawn later,
    # are not hidden underneath them.
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    height, width = sorted_anns[0]['segmentation'].shape[:2]
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0  # fully transparent until a mask covers the pixel

    for ann in sorted_anns:
        mask = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.3511]])
        img[mask] = color_mask
    return img


def generate_mask(pil_image):
    """Run SAM automatic mask generation on an image and save a preview.

    The image is downscaled to half size, segmented with
    ``SamAutomaticMaskGenerator`` built on the module-level ``sam``
    model, and a coloured overlay of the masks is written to
    ``'result.jpg'`` in the current working directory.

    Args:
        pil_image: Image convertible with ``np.array()``; presumably a
            PIL image, whose pixel data is already in RGB channel order
            (confirm against callers).

    Returns:
        The list of mask records returned by
        ``SamAutomaticMaskGenerator.generate``. When the list is empty
        it is returned as-is and no preview file is written.
    """
    image = np.array(pil_image)
    image = cv2.resize(image, None, fx=0.5, fy=0.5)

    # A PIL-derived array is already RGB, so a BGR->RGB swap here would
    # hand the model inverted channels. Only normalise to 3-channel RGB.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    print(len(masks))
    if not masks:
        # Nothing to inspect or draw; avoid indexing an empty list.
        return masks

    print(masks[0].keys())
    print(masks[0])

    result_image = show_anns(masks)
    # Overlay values are floats in [0, 1]; scale to 8-bit for saving.
    cv2.imwrite('result.jpg', (result_image * 255).astype(np.uint8))

    return masks

# Disabled standalone example, kept as an unused string literal (it is never
# executed): it runs the same steps as generate_mask() on an image loaded
# with cv2.imread from 'images/road.png'.
'''
image = cv2.imread('images/road.png')
image = cv2.resize(image, None, fx=0.5, fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
print(len(masks))
print(masks[0].keys())
print(masks[0])

result_image = show_anns(masks)
cv2.imwrite('result.jpg', (result_image * 255).astype(np.uint8))
'''


def get_image_embedding(pil_image):
    """Return the SAM image embedding for the given image.

    A fresh ``SamPredictor`` is created around the module-level ``sam``
    model, the image is converted with ``np.array()`` and registered via
    ``set_image``, and the predictor's embedding is handed back.

    Args:
        pil_image: Image convertible with ``np.array()``; presumably a
            PIL image (confirm against callers).

    Returns:
        Whatever ``SamPredictor.get_image_embedding()`` returns for the
        image that was just set.
    """
    sam_predictor = SamPredictor(sam)
    pixels = np.array(pil_image)
    sam_predictor.set_image(pixels)
    return sam_predictor.get_image_embedding()

# Leave a Comment