Untitled

 avatar
unknown
python
a year ago
2.5 kB
5
Indexable
    # Tiled inference: split the image into a tiles x tiles grid, run the model on
    # every tile (each upscaled to 1024x1024) as one batch, then map the per-tile
    # predictions back into whole-image coordinates and merge them.
    # NOTE(review): img, bboxes, model and device come from an enclosing scope not
    # shown here. The 4-slice indexing below means img has (at least) 4 dims;
    # presumably (batch, channels, H, W) — confirm with the caller.
    # NOTE(review): the `* tiles` bbox scaling and the `/ 1024` normalization further
    # down agree with the 1024x1024 upscale only when tile * tiles == 1024, i.e. the
    # input image is (about) 1024x1024. Verify that this assumption holds for all callers.
    tiles = 2  # grid is tiles x tiles
    # Tile side length, taken from dim 2 only and reused for dim 3, so tiles are
    # square; any remainder from the floor division is dropped.
    tile = img.shape[2] // tiles
    imgs = []
    # Row-major order: flat tile index k = i * tiles + j, i = row (dim 2), j = column (dim 3).
    for i in range(tiles):
        for j in range(tiles):
            imgs.append(img[:, :, i * tile:(i + 1) * tile, j * tile:(j + 1) * tile])

    # Shift the bboxes into each tile's local coordinates by subtracting the tile's
    # top-left corner. The offset is laid out as [x, y, x, y] = [col, row, col, row],
    # so the boxes presumably use (x1, y1, x2, y2) pixel order — confirm.
    bboxes_ = []
    for i in range(tiles):
        for j in range(tiles):
            bboxes_.append(bboxes - torch.tensor([j * tile, i * tile, j * tile, i * tile]).to(device))

    # Keep only the bboxes lying entirely inside each tile: drop any box with a
    # coordinate < 0 or > the tile width (imgs[i].shape[3]). The width also bounds the
    # y coordinates, which relies on the tiles being square. Boxes that cross a tile
    # border therefore end up in no tile at all.
    # The `.any(dim=2)` implies bboxes has >= 3 dims (presumably (1, N, 4)); with 3
    # dims, the boolean-mask indexing leaves each bboxes_[i] as (K_i, 4).
    for i in range(tiles*tiles):
        bboxes_[i] = bboxes_[i][torch.logical_not(
            torch.logical_or((bboxes_[i] < 0).any(dim=2), (bboxes_[i] > imgs[i].shape[3]).any(dim=2)))]



    # Pad every tile's box list with all-zero rows up to the largest count so the
    # lists can be stacked into one tensor.
    # NOTE(review): if no box fits in any tile, max_num_bboxes is 0. Check that the
    # model accepts an empty exemplar set.
    max_num_bboxes = max([bboxes_[i].shape[0] for i in range(tiles*tiles)])
    for i in range(tiles*tiles):
        bboxes_[i] = torch.cat([bboxes_[i], torch.zeros((max_num_bboxes - bboxes_[i].shape[0], 4)).to(device)])

    # (tiles*tiles, max_num_bboxes, 4). Scaling by `tiles` maps tile-local pixel
    # coordinates onto the upscaled tile (see the 1024 assumption above).
    bboxes_batch = torch.stack([bb for bb in bboxes_])*tiles
    # Only element 0 of dim 0 of each tile is used; presumably the batch size is 1.
    img_batch = torch.stack([im[0] for im in imgs])

    # Upscale every tile to 1024x1024 (bilinear) before passing it to the model.
    img_batch = torch.nn.functional.interpolate(img_batch, size=1024, mode='bilinear', align_corners=False)

    # Only `outputs` is used below; it is indexed per tile as a dict with the keys
    # 'pred_boxes', 'scores' and 'box_v'.
    outputs, ref_points, centerness, outputs_coord, masks = model(img_batch, bboxes_batch)


    # A tile whose pred_boxes lacks a trailing dimension of 4 (presumably a tile with
    # no detections) gets a single all-zero placeholder box, score and box_v, so the
    # concatenation below has a uniform shape. The placeholder is kept in the merged
    # result as a zero box with score 0.
    for _ in range(tiles*tiles):
        if outputs[_]['pred_boxes'].shape[-1] != 4:
            outputs[_]['pred_boxes'] = torch.zeros((1,1, 4)).to(device)
            outputs[_]['scores'] = torch.zeros((1,1)).to(device)
            outputs[_]['box_v'] = torch.zeros((1,1)).to(device)

    # Map each tile's predictions back to whole-image coordinates. Dividing by
    # `tiles` shrinks the box to the tile's share of the image. The tile's top-left
    # corner (i % tiles = column -> x, i // tiles = row -> y) is then added, divided
    # by 1024. This suggests pred_boxes are normalized to [0, 1] of the tile, and it
    # is only correct when img is 1024 pixels wide/high — confirm.
    # The trailing [0] takes the first entry along dim 0 (presumably the batch dim).
    merge_bboxes = []
    for i in range(tiles*tiles):
        merge_bboxes.append((outputs[i]["pred_boxes"] / tiles + torch.tensor([(i % tiles) * img.shape[3] // tiles, (i // tiles) * img.shape[2] // tiles, (i % tiles) * img.shape[3] // tiles, (i // tiles) * img.shape[2] // tiles]).to(device) / 1024)[0])
    bboxes_pred = torch.cat(merge_bboxes)

    # Concatenate the per-tile scores and box_v in the same tile order as the boxes,
    # then repackage everything as a single-element list holding one merged dict.
    scores = torch.cat([outputs[i]["scores"][0] for i in range(tiles*tiles)])
    box_v = torch.cat([outputs[i]["box_v"][0] for i in range(tiles*tiles)])
    outputs = [{"pred_boxes": bboxes_pred, "scores": scores, "box_v": box_v}]
    
    
    
    
    # After extracting the prototypes from the features, something along these lines:
    # Drop exemplar rows whose norm over the last dim is 0 (presumably all-zero
    # padding rows — confirm), then repeat the surviving rows for every batch element.
    # Order matters: `shape` is filtered with the mask computed from the ORIGINAL
    # exemplars, so this line must stay before `exemplars` is reassigned below.
    # NOTE(review): boolean-mask indexing flattens all masked leading dims into one.
    # If exemplars is (bs, N, D) with bs > 1, the kept rows from every batch element
    # are pooled and then repeated, so each batch element sees all of them. Verify
    # that this is intended.
    shape = shape[exemplars.norm(dim=-1) > 0]
    shape = shape.repeat(bs, 1, 1)
    exemplars = exemplars[exemplars.norm(dim=-1) > 0]
    exemplars = exemplars.repeat(bs, 1, 1)

    # Concatenate along dim 1: exemplar rows first, then shape rows. The two tensors
    # must match in every other dimension.
    prototype_embeddings = torch.cat([exemplars, shape], dim=1)
Editor is loading...
Leave a Comment