Untitled
unknown
python
9 months ago
613 B
7
Indexable
def __getitem__(self, idx):
    """Load one image/mask pair, optionally augment it, and convert to tensors.

    Args:
        idx: Index into both ``self.image_paths`` and ``self.mask_paths``.
            The same ``idx`` is used for both, so the two sequences are
            presumably aligned pairwise (TODO confirm against the constructor).

    Returns:
        A ``(image, mask)`` tuple of float tensors:

        - ``image``: channels-first (the numpy HWC array is transposed to
          CHW) and divided by 255.0. With no transform it is ``(3, H, W)``
          RGB with values in [0, 1].
        - ``mask``: same spatial shape as the (possibly transformed) mask,
          holding 1.0 where the pixel value is > 128 and 0.0 elsewhere.

    Raises:
        Whatever indexing the path containers raises for an out-of-range
        ``idx``, and any error from opening/decoding the files with PIL.
    """
    img_path = self.image_paths[idx]
    mask_path = self.mask_paths[idx]

    # Context managers release the file handles deterministically; convert()
    # materializes the pixel data, so the arrays remain valid after closing.
    with Image.open(img_path) as img:
        image = np.array(img.convert('RGB'))
    with Image.open(mask_path) as msk:
        mask = np.array(msk.convert('L'))  # single-channel 8-bit grayscale

    if self.transform:
        # Called with keyword args and indexed by 'image'/'mask' keys —
        # presumably an albumentations-style Compose. The code below assumes
        # both outputs are still numpy arrays (image HWC); a transform that
        # ends in a tensor conversion would break it — TODO confirm.
        augmented = self.transform(image=image, mask=mask)
        image = augmented['image']
        mask = augmented['mask']

    # HWC -> CHW. ascontiguousarray also copies away negative strides (e.g.
    # from a flip implemented by slicing), which torch.from_numpy rejects.
    chw = np.ascontiguousarray(image.transpose(2, 0, 1))
    image = torch.from_numpy(chw).float() / 255.0

    # Binarize: pixels strictly above 128 -> 1.0, everything else -> 0.0.
    # NOTE(review): assumes masks are stored as 0/255; a mask encoded as 0/1
    # would binarize to all zeros — confirm against the dataset.
    mask = torch.from_numpy(np.ascontiguousarray(mask)).float()
    mask = (mask > 128).float()
    return image, mask
Leave a Comment