Untitled
def __getitem__(self, idx):
    """Return the ``(image, mask)`` tensor pair for sample ``idx``.

    The image at ``self.image_paths[idx]`` is loaded as RGB and the mask at
    ``self.mask_paths[idx]`` as 8-bit greyscale ('L'). If ``self.transform``
    is truthy it is called as ``self.transform(image=..., mask=...)`` and must
    return a mapping with ``'image'`` and ``'mask'`` keys.

    Args:
        idx: Index into ``self.image_paths`` and ``self.mask_paths``.

    Returns:
        A tuple ``(image, mask)``:

        * ``image``: float32 tensor, channels-first ``(C, H, W)``. Integer
          pixel data is scaled from ``[0, 255]`` to ``[0, 1]``. Floating-point
          data returned by the transform is passed through unscaled.
        * ``mask``: float32 tensor with the same shape as the (possibly
          transformed) mask, i.e. ``(H, W)`` for the 2-D mask loaded here.
          It is 1.0 where the mask value is strictly greater than 128 and
          0.0 elsewhere.
    """
    img_path = self.image_paths[idx]
    mask_path = self.mask_paths[idx]

    # Context managers release the file handles as soon as the pixels have
    # been copied into NumPy, rather than whenever the PIL objects happen to
    # be garbage-collected.
    with Image.open(img_path) as img:
        image = np.array(img.convert('RGB'))
    with Image.open(mask_path) as msk:
        mask = np.array(msk.convert('L'))

    if self.transform:
        augmented = self.transform(image=image, mask=mask)
        image = augmented['image']
        mask = augmented['mask']

    if isinstance(image, torch.Tensor):
        # NOTE(review): assumes a tensor-producing transform (e.g.
        # albumentations ToTensorV2) already returns channels-first
        # (C, H, W); confirm against the transform pipeline in use.
        image_t = image
    else:
        # HWC -> CHW. ascontiguousarray copies into a compact buffer, which
        # also avoids torch.from_numpy failing on negative strides left by
        # flip-style augmentations.
        image_t = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))

    if image_t.is_floating_point():
        # Float output is taken as already scaled/normalized by the
        # transform; dividing by 255 again would squash it towards zero.
        image_t = image_t.float()
    else:
        image_t = image_t.float() / 255.0

    if not isinstance(mask, torch.Tensor):
        mask = torch.from_numpy(np.ascontiguousarray(mask))
    # Binarize: anything strictly above mid-grey (128) counts as foreground.
    mask_t = (mask > 128).float()

    return image_t, mask_t
Leave a Comment