Untitled

mail@pastecode.io avatar
unknown
plain_text
2 years ago
1.2 kB
3
Indexable
Never
def __getitem__(self, idx):
        """Load the image / ground-truth mask pair for row ``idx`` of ``self.df``.

        The image is read from ``<path>.mhd`` and the mask from
        ``<path>_gt.mhd``, where ``<path>`` is
        ``<self.root>/<row[0]>/<all row fields joined by "_">``.
        Both are resized to height ``self.image_size[0]`` with a width chosen
        so that the physical (mm) aspect ratio given by the image spacing is
        preserved. Both are then center-cropped (torchvision zero-pads when
        smaller) to ``self.image_size``.

        Args:
            idx: Row label passed to ``self.df.loc``. NOTE(review): ``.loc`` is
                label-based; this assumes the DataFrame index matches the
                positional indices a DataLoader/sampler supplies. TODO confirm.

        Returns:
            ``(image, mask)``. ``image`` is a float tensor scaled by 1/255.
            ``mask`` is a ``torch.long`` tensor with all size-1 dimensions
            squeezed out.
        """
        # presumably every field of the row is a string (required by "_".join)
        row = list(self.df.loc[idx])
        path = os.path.join(self.root, row[0], "_".join(row))
        image_sitk = sitk.ReadImage(f'{path}.mhd', sitk.sitkFloat32)
        mask_sitk = sitk.ReadImage(f'{path}_gt.mhd', sitk.sitkFloat32)

        # Physical aspect (height/width in mm) = pixel aspect (y-spacing /
        # x-spacing) times image aspect in pixels (rows / columns).
        spacing = image_sitk.GetSpacing()
        pixel_aspect = spacing[1] / spacing[0]
        image_aspect = image_sitk.GetHeight() / image_sitk.GetWidth()
        physical_aspect = pixel_aspect * image_aspect

        # convert to tensors; the /255 scaling presumably assumes an 8-bit
        # intensity range in the source image — TODO confirm
        image = torch.Tensor(sitk.GetArrayFromImage(image_sitk) / 255)
        mask = torch.Tensor(sitk.GetArrayFromImage(mask_sitk))

        # Resize to the target height and pick the width that keeps the
        # physical aspect ratio, so resized pixels are square in mm:
        # width = height / (H_mm / W_mm). Clamp to >= 1 so resize never gets 0.
        target_height = self.image_size[0]
        target_width = max(1, round(target_height / physical_aspect))
        size = (target_height, target_width)

        image = resize(image, size, interpolation=InterpolationMode.BICUBIC)
        # nearest-neighbour keeps mask values as the original label values
        mask = resize(mask, size, interpolation=InterpolationMode.NEAREST)

        image = center_crop(image, self.image_size)
        mask = center_crop(mask, self.image_size)
        # drops singleton axes (e.g. a depth-1 axis from a 3-D .mhd) —
        # assumes the mask has exactly one non-trivial 2-D plane; TODO confirm
        mask = mask.squeeze()

        return image, mask.to(torch.long)