skip_differential_rendering

 avatar
unknown
python
3 years ago
2.9 kB
2
Indexable
    def skip_differential_rendering(self, points, rgbs, vps, K, bn=0, j=0, save=True,
                                    image_size=128):
        """Project a batch of coloured point clouds into depth, colour and mask images.

        Each point is projected with the intrinsics ``K`` (``image = K @ p``,
        ``u = image_x / image_z``, ``v = image_y / image_z``). Its depth and
        colour are written into the single pixel it falls in. The pixel
        coordinates are truncated to integers, so the placement itself is not
        differentiable; only the written depth/colour values carry gradients.

        Inputs:
            points.shape: b x N x 3 (e.g. b x (128*128) x 3 = b x 16384 x 3)
            rgbs.shape:   b x N x 3
            vps:          not used by this method.
            K.shape:      b x 3 x 3
            bn, j:        only used to build the debug file names when ``save``.
            save:         if True, write per-sample depth/colour PNGs into ./SelfRenders/.
            image_size:   side length of the square output images (default 128).
        Outputs (on the same device as ``points``):
            batch_depth, shape : (b x 1 x image_size x image_size)
            batch_rgb, shape   : (b x 3 x image_size x image_size)
            batch_mask, shape  : (b x 1 x image_size x image_size)

        Notes:
            * Points with non-positive depth (behind the camera) are discarded.
              Without this, a point with Z < 0 can still land on a valid pixel
              because the numerator and denominator of u, v both flip sign.
            * Colours are min-max normalised over the whole batch, not per image.
            * NOTE(review): when several points hit the same pixel there is no
              z-test. Which point wins is left to advanced-index assignment,
              which PyTorch documents as undefined for duplicate indices.
              A z-buffer may be intended here.
            * NOTE(review): images are indexed as [u, v], so u (the x coordinate)
              selects the first spatial axis. The result is the transpose of the
              usual [row=v, col=u] layout. This is kept as-is because downstream
              code may rely on it; it is only consistent because H == W.
        """
        batch_size = points.shape[0]
        H = W = image_size
        device = points.device

        # (b x 3 x 3) @ (b x 3 x N) -> b x 3 x N homogeneous pixel coordinates.
        image = K.bmm(points.permute(0, 2, 1))
        Z_val = image[:, 2, :]
        image_u = image[:, 0, :] / Z_val
        image_v = image[:, 1, :] / Z_val

        # Global min-max normalisation of the colours (kept in b x N x 3 layout).
        # If every value is identical the range is 0 and the numerator is 0 too.
        # Dividing by 1 in that case yields 0 instead of NaN from 0/0.
        rgb_min = rgbs.min()
        rgb_range = rgbs.max() - rgb_min
        rgb_range = torch.where(rgb_range > 0, rgb_range, torch.ones_like(rgb_range))
        colors = (rgbs - rgb_min) / rgb_range

        depth_img = torch.zeros([batch_size, H, W], device=device)
        color_img = torch.zeros([batch_size, H, W, 3], device=device)

        # Each image has a different number of valid points, so splat per sample.
        for b in range(batch_size):
            u, v, z = image_u[b], image_v[b], Z_val[b]
            # NaN/inf coordinates (from Z == 0) fail these comparisons and are dropped.
            valid = (z > 0) & (u >= 0) & (u < W) & (v >= 0) & (v < H)
            u_idx = u[valid].to(torch.int64)
            v_idx = v[valid].to(torch.int64)

            # Index assignment requires matching dtypes, so cast the source values.
            depth_img[b, u_idx, v_idx] = z[valid].to(depth_img.dtype)
            color_img[b, u_idx, v_idx, :] = colors[b, valid, :].to(color_img.dtype)

            if save:
                stem = f"./SelfRenders/img-{bn}_{b}_{j + 1}"
                plt.imsave(stem + ".png", depth_img[b].cpu().detach().numpy(), cmap="hot")
                plt.imsave(stem + "c.png", color_img[b].cpu().detach().numpy())

        # A pixel is covered when some point wrote a (strictly positive) depth into it.
        mask = (depth_img != 0).to(depth_img.dtype)

        dep_batch = depth_img[:, None, :, :]
        mask_batch = mask[:, None, :, :]
        color_batch = color_img.permute(0, 3, 1, 2)

        return dep_batch, color_batch, mask_batch  # (b x 1 x H x W), (b x 3 x H x W), (b x 1 x H x W)
Editor is loading...