Untitled

mail@pastecode.io avatar
unknown
python
2 months ago
2.1 kB
2
Indexable
Never
# Few-shot box inference over `image_list`.
# For every image: keep the highest-priority cached proposals, score each one
# against the image's exemplar features with `cls_head`, pick the best of the
# candidate boxes per proposal, rescale to the original image size, drop
# degenerate / oversized boxes, run NMS and wrap the survivors in `Instances`.
# NOTE(review): `all_data`, `image_list`, `cls_head`, `vision_ops`, `Instances`
# and `Boxes` come from outside this snippet; their exact types are assumed.

min_scores = 0.05        # proposals with a point score below this are dropped
max_points = 1000        # only the first `max_points` proposals (stored order) are considered
score_batch_size = 128   # proposals scored per `cls_head` call
nms_iou_threshold = 0.5
model_input_size = 1024.  # presumably the longest side proposals were predicted at — confirm
max_area_fraction = 0.75  # boxes covering >= this fraction of the image are discarded
min_box_area = 10         # boxes with area <= this (pixels^2) are discarded

cls_head.eval()  # inference only; the mode never changes inside the loop
for fname in tqdm.tqdm(image_list):
    data = all_data[fname]
    predictions = data['predictions']
    features = data['features'].cuda()
    # few shot: exemplar features the proposals are compared against
    example_features = data['example_clip_features'].cuda()

    # Keep the first `max_points` proposals whose point score is not below
    # `min_scores`.
    pred_points_score = predictions['pred_points_score']
    mask = ~(pred_points_score < min_scores)
    mask[max_points:] = False
    pred_boxes = predictions['pred_boxes'][mask].cuda()
    pred_ious = predictions['pred_ious'][mask].cuda()

    # Both lists must be reset per image; previously `all_pred_boxes` was never
    # initialised (NameError, or boxes leaking across images).
    all_pred_boxes = []
    all_pred_scores = []
    with torch.no_grad():
        for indices in torch.arange(len(pred_boxes)).split(score_batch_size):
            cls_outs_ = cls_head(features, [pred_boxes[indices], ], [example_features, ] * len(indices))
            # Average sigmoid scores over exemplars. The trailing 5 is presumably
            # the number of candidate boxes per proposal (matching the second dim
            # of `pred_boxes`/`pred_ious`) — confirm.
            pred_logits = cls_outs_.sigmoid().view(-1, len(example_features), 5).mean(1)
            # Weight by predicted IoU so well-localised candidates win.
            pred_logits = pred_logits * pred_ious[indices]

            best_scores, best_candidates = pred_logits.max(dim=1)
            all_pred_boxes.append(pred_boxes[indices, best_candidates])
            all_pred_scores.append(best_scores)

    height, width = data['height'], data['width']
    if all_pred_boxes:
        pred_boxes = torch.cat(all_pred_boxes)
        pred_scores = torch.cat(all_pred_scores)
    else:
        # No proposal survived filtering; torch.cat([]) would raise here.
        pred_boxes = pred_boxes.new_zeros((0, 4))
        pred_scores = pred_boxes.new_zeros((0,))

    # Map boxes from model-input coordinates back to the original image and
    # clip them to its bounds (x coords: columns 0/2, y coords: columns 1/3).
    scale = max(height, width) / model_input_size
    pred_boxes = pred_boxes * scale
    pred_boxes[:, [0, 2]] = pred_boxes[:, [0, 2]].clamp(0, width)
    pred_boxes[:, [1, 3]] = pred_boxes[:, [1, 3]].clamp(0, height)

    # Discard near-empty boxes and boxes that cover most of the image.
    box_area = vision_ops.box_area(pred_boxes)
    mask = (box_area < (height * width * max_area_fraction)) & (box_area > min_box_area)
    pred_boxes = pred_boxes[mask]
    pred_scores = pred_scores[mask]

    nms_indices = vision_ops.nms(pred_boxes, pred_scores, nms_iou_threshold)
    pred_boxes = pred_boxes[nms_indices]
    pred_scores = pred_scores[nms_indices]

    instances = Instances((height, width))
    instances.pred_boxes = Boxes(pred_boxes)
    instances.scores = pred_scores
    # Single-class (class-agnostic) output: every detection gets class 0.
    instances.pred_classes = torch.zeros(len(pred_boxes), dtype=torch.long, device=pred_boxes.device)
    # NOTE(review): `instances` is not stored or consumed anywhere in this loop
    # body as shown; presumably the original passed it to an evaluator — confirm.
   
Leave a Comment