Untitled
unknown
plain_text
2 years ago
1.2 kB
3
Indexable
Never
def __getitem__(self, idx):
    """Load, rescale and crop the image/mask pair for sample ``idx``.

    File locations come from the ``idx``-th row of ``self.df``: the image is
    ``<self.root>/<row[0]>/<fields joined by "_">.mhd`` and the mask is the
    same stem with a ``_gt.mhd`` suffix.

    Both arrays are resized so their height equals ``self.image_size[0]``
    and their width keeps the physical (spacing-corrected) aspect ratio of
    the original image. They are then center-cropped to ``self.image_size``.

    Args:
        idx: 0-based position of the sample in ``self.df``.

    Returns:
        ``(image, mask)``: ``image`` is a float32 tensor of intensities
        divided by 255; ``mask`` is a ``torch.long`` tensor with singleton
        dimensions squeezed out.
    """
    # Positional lookup: samplers pass 0..len-1, which only coincides with
    # ``.loc`` labels when the frame has a default RangeIndex.
    # NOTE(review): assumes every column of self.df is a path fragment — confirm.
    row = [str(field) for field in self.df.iloc[idx]]
    path = os.path.join(self.root, row[0], "_".join(row))

    image_sitk = sitk.ReadImage(f'{path}.mhd', sitk.sitkFloat32)
    # SimpleITK reports spacing in (x, y[, z]) order: spacing[0] is the
    # column (width) step, spacing[1] the row (height) step.
    spacing = image_sitk.GetSpacing()
    height, width = image_sitk.GetHeight(), image_sitk.GetWidth()

    # NOTE(review): dividing by 255 assumes an 8-bit intensity range — confirm.
    image = sitk.GetArrayFromImage(image_sitk) / 255
    mask = sitk.GetArrayFromImage(sitk.ReadImage(f'{path}_gt.mhd', sitk.sitkFloat32))

    # Height/width ratio of the imaged region in physical units. The target
    # width must scale with the target height; the previous formula
    # (W * H/W * sy/sx == H*sy/sx) ignored the target height entirely.
    physical_aspect = (height * spacing[1]) / (width * spacing[0])
    target_height = self.image_size[0]
    target_width = max(1, round(target_height / physical_aspect))
    size = (target_height, target_width)

    image, mask = torch.Tensor(image), torch.Tensor(mask)
    image = resize(image, size, interpolation=InterpolationMode.BICUBIC)
    # Nearest-neighbour interpolation never blends labels, so the mask keeps
    # only values that were present before resizing.
    mask = resize(mask, size, interpolation=InterpolationMode.NEAREST)

    image = center_crop(image, self.image_size)
    mask = center_crop(mask, self.image_size)
    mask = mask.squeeze()
    return image, mask.to(torch.long)