Untitled
unknown
plain_text
a year ago
10 kB
3
Indexable
Never
import codecs
import gc
import logging
import os
import pickle
import re
import tempfile
import traceback as tb
import uuid
from os.path import isfile

import requests
from skimage.transform import resize
from skimage.measure import label
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
# NOTE(review): the returned object is discarded here; `tf.device` is normally
# used as a context manager, so this call presumably pins nothing - confirm.
tf.device('GPU:0')

from tensorflow.keras.models import load_model
import nibabel as nib
import numpy as np
import SimpleITK as sitk
from dynaconf import settings
from scipy import ndimage
# Rebinds `tf` from the compat.v1 namespace to the top-level package (kept
# exactly as in the original; `disable_v2_behavior()` has already run).
import tensorflow as tf

_MODELS_FOLDER = os.getenv('MODELS_FOLDER')

sitk.ProcessObject_GlobalDefaultDebugOff()
sitk.ProcessObject_GlobalWarningDisplayOff()

TASKS_SERVER_HOSTNAME = os.environ.get('TASKS_SERVER_HOSTNAME')
TASKS_SERVER_PORT = os.environ.get('TASKS_SERVER_PORT')
TASKS_SERVER_URL = f"http://{TASKS_SERVER_HOSTNAME}:{TASKS_SERVER_PORT}/"

logging.basicConfig(level=logging.INFO)


class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of resized NIfTI cases and masks.

    Each batch is a pair ``(X, y)`` of arrays shaped
    ``(batch_size, n_channels, *dim)``.
    """

    def __init__(self, niftis, masks, batch_size=10, dim=(256, 256, 50), n_channels=1,
                 n_classes=10, shuffle=True, mask_down_resize_order: int = 3):
        """Store the configuration and build the initial sample order.

        Args:
            niftis: Paths of the case volumes.
            masks: Paths of the mask volumes, index-aligned with ``niftis``.
            batch_size: Number of samples per batch.
            dim: Target spatial shape every volume is resized to.
            n_channels: Size of the channel axis of the output arrays.
            n_classes: Stored on the instance; not used by this class itself.
            shuffle: Whether to reshuffle the sample order at each epoch end.
            mask_down_resize_order: Interpolation order used to resize masks.
        """
        super().__init__()
        self.niftis = niftis
        self.masks = masks
        self.batch_size = batch_size
        self.dim = dim
        self.list_IDs = list(range(len(niftis)))
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.mask_down_resize_order = mask_down_resize_order
        # Called last so every attribute exists before the first index build.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch.

        A trailing partial batch is dropped because ``__data_generation``
        always allocates ``batch_size`` rows. The previous implementation
        subtracted one extra batch whenever the sample count was not a
        multiple of ``batch_size`` (and returned -1 for fewer samples than
        one batch).
        """
        return len(self.niftis) // self.batch_size

    def __getitem__(self, index):
        """Generate batch number ``index`` as an ``(X, y)`` pair."""
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]

        niftis_temp = [self.niftis[k] for k in batch_indexes]
        masks_temp = [self.masks[k] for k in batch_indexes]

        return self.__data_generation(niftis_temp, masks_temp)

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    @staticmethod
    def remove_small_components(mask, minimum_cc_sum=100):
        """Zero out, in place, connected components whose sum is too small.

        Args:
            mask: Array to clean; it is modified in place and also returned.
            minimum_cc_sum: Components whose summed mask values fall below
                this threshold are cleared.

        Returns:
            The same ``mask`` array.
        """
        labelled_mask, num_labels = ndimage.label(mask)
        # NOTE(review): the original indentation was lost; the loop is assumed
        # to be guarded by this condition, so a lone component is never removed.
        if num_labels > 1:
            print('removing small components')
            # ndimage.label numbers components 1..num_labels (0 is background),
            # so start at 1 and include the last component.
            for component_id in range(1, num_labels + 1):
                component = labelled_mask == component_id
                if np.sum(mask[component]) < minimum_cc_sum:
                    mask[component] = 0
        return mask

    def __data_generation(self, niftis_temp, masks_temp):
        """Load, normalise and resize one batch of cases and masks.

        Resized volumes are cached next to their sources as
        ``*_resized.nii.gz`` to skip the resize on later loads. The cache is
        bypassed when the cache path would collide with another file involved
        in the same sample (see ``cache_usable`` below).

        Returns:
            Tuple ``(X, y)``, both shaped ``(batch_size, n_channels, *dim)``;
            ``y`` has an integer dtype.

        Raises:
            ValueError: If a case volume has two or fewer distinct values.
        """
        X = np.empty((self.batch_size, self.n_channels, *self.dim))
        y = np.empty((self.batch_size, self.n_channels, *self.dim), dtype=int)

        for i, ID in enumerate(niftis_temp):
            mask_path = masks_temp[i]
            print(" == > masks_temp[i] " , mask_path)
            print(" == > ID " , ID)

            case_resized_path = ID.replace(".nii.gz", "_resized.nii.gz")
            mask_resized_path = mask_path.replace(".nii.gz", "_resized.nii.gz")
            # Caching is only safe when the two cache files are distinct and
            # neither one is the source file itself. Otherwise the mask would
            # overwrite the cached case (same path passed for both), or a path
            # without ".nii.gz" would overwrite the original input.
            cache_usable = (
                case_resized_path != mask_resized_path
                and case_resized_path != ID
                and mask_resized_path != mask_path
            )

            if cache_usable and isfile(case_resized_path) and isfile(mask_resized_path):
                X[i, ] = nib.load(case_resized_path).get_fdata()
                y[i, ] = nib.load(mask_resized_path).get_fdata()
                continue

            nifti_file = nib.as_closest_canonical(nib.load(ID))
            case = nifti_file.get_fdata()
            if np.unique(case).size <= 2:
                raise ValueError(
                    'It looks like the case file is in fact a mask file. '
                    f'It has les or equal to 2 unique variables: {ID}')

            case = np.clip(case, -150, 150)
            case = np.expand_dims(resize(case, self.dim), axis=0)
            # NOTE(review): this equals min-max scaling only when min <= 0 <= max;
            # left unchanged so the model sees the same preprocessing - confirm.
            case = (case + abs(case.min())) / (abs(case.max()) + abs(case.min()))
            X[i, ] = case

            mask_file = nib.as_closest_canonical(nib.load(mask_path))
            mask_case = np.clip(mask_file.get_fdata(), 0, 1)
            mask_case = np.expand_dims(
                resize(mask_case, self.dim, anti_aliasing=False,
                       order=self.mask_down_resize_order),
                axis=0)
            # Re-binarise after interpolation.
            mask_case = (mask_case > 0.5).astype(mask_case.dtype)
            y[i, ] = mask_case

            if cache_usable:
                nib.save(nib.Nifti1Image(case, nifti_file.affine), case_resized_path)
                nib.save(nib.Nifti1Image(mask_case, mask_file.affine), mask_resized_path)

        return X, y


def liver_segmentation_new_cases(CT_Path, ct_cropped, tmp_liver_cropped, model_name):
    """Predict a mask for one CT volume and write it at the CT's resolution.

    Args:
        CT_Path: Path of the input CT NIfTI file.
        ct_cropped: Output path for the CT re-saved in closest-canonical
            orientation.
        tmp_liver_cropped: Output path for the predicted binary mask.
        model_name: Path of the Keras model file to load.
    """
    def getLargestCC(segmentation):
        """Return a boolean array selecting the largest connected component."""
        labels = label(segmentation)
        assert labels.max() != 0, "segmentation contains no connected component"
        # bincount()[1:] skips the background, hence the +1 on the argmax.
        return labels == np.argmax(np.bincount(labels.flat)[1:]) + 1

    params = {'dim': (128, 128, 48),
              'batch_size': 1,
              'n_classes': 6,
              'n_channels': 1,
              'shuffle': True}

    # mdl_after_richard_fix_with_new_augs.h5
    # old network - mdl_after_richard_fix.h5
    model = load_model(model_name, compile=False)
    # The CT path doubles as the "mask" path; the generated mask is discarded.
    test_generator = DataGenerator([CT_Path], [CT_Path], **params, mask_down_resize_order=3)
    a, _ = test_generator[0]
    predictions = model.predict(a)
    gc.collect()
    predictions = predictions[0][0]

    nifti_file = nib.load(CT_Path)
    logging.warning("shape before as_closest_canonical: %r", nifti_file.shape)
    nifti_file = nib.as_closest_canonical(nifti_file)
    logging.warning("shape after as_closest_canonical: %r", nifti_file.shape)
    nib.save(nifti_file, ct_cropped)
    case = nifti_file.get_fdata()

    predictions[predictions < 0.5] = 0
    predictions[predictions >= 0.5] = 1
    # Cast to float before resizing: linear interpolation (order=1) is not
    # defined for boolean arrays and recent scikit-image versions reject it.
    largest_component = getLargestCC(predictions).astype(float)
    predictions = resize(largest_component, case.shape, anti_aliasing=False, order=1)
    predictions = (predictions > 0.5).astype(predictions.dtype)

    predictions_nifti = nib.Nifti1Image(predictions, nifti_file.affine)
    nib.save(predictions_nifti, tmp_liver_cropped)


def _model_fullpath(model_filename):
    """Return ``model_filename`` joined onto the ``MODELS_FOLDER`` directory."""
    return os.path.join(_MODELS_FOLDER, model_filename)


def _change_the_sizes(ct_file, temp_liver_cropped):
    """Crop ``ct_file`` along its third axis to the mask extent plus a margin.

    The CT file is overwritten in place and the temporary mask file is
    deleted afterwards.

    Args:
        ct_file: Path of the CT NIfTI file to crop (overwritten).
        temp_liver_cropped: Path of the mask NIfTI file (removed at the end).
    """
    liver_file = nib.load(temp_liver_cropped)
    liver_case = liver_file.get_fdata()
    z_mm = liver_file.header.get_zooms()[2]

    nifti_file = nib.load(ct_file)
    nifti_case = nifti_file.get_fdata()
    logging.warning("in _change_the_sizes, nifti_case shape before cropping: %r",
                    nifti_case.shape)

    # First and last slice indices that contain any mask voxel.
    occupied_slices = np.where(np.any(liver_case, axis=(0, 1)))[0]
    zmin, zmax = occupied_slices[[0, -1]]

    # Number of slices covering a 1.5 (zoom units) margin on each side.
    rizika = int(np.ceil(1.5 / z_mm))
    zmin = max(0, zmin - rizika)
    # `zmax` is an inclusive index while the slice end is exclusive, so +1 is
    # needed to keep the upper margin as large as the lower one.
    zmax = min(nifti_case.shape[2], zmax + rizika + 1)

    nifti_case = nifti_case[:, :, zmin:zmax]
    logging.warning("in _change_the_sizes, nifti_case shape after cropping: %r",
                    nifti_case.shape)

    # NOTE(review): the uncropped affine is reused, so for zmin > 0 the saved
    # origin is not shifted with the data - confirm downstream expectations.
    nii = nib.Nifti1Image(nifti_case, nifti_file.affine)
    nib.save(nii, ct_file)
    os.remove(temp_liver_cropped)


def liver_roi_cropping(**kwargs):
    """Segment, then crop, the CT given by the ``input_image`` keyword.

    Keyword Args:
        input_image: Path of the source CT file.
        output_image: Path the cropped CT is written to.
    """
    ct_file = kwargs.get('input_image')
    ct_output_file = kwargs.get('output_image')
    temp_cropped_liver = os.path.join(
        os.path.dirname(ct_file),
        f"temp-liver-cropped-{os.path.basename(ct_file)}")

    liver_segmentation_new_cases(ct_file, ct_output_file, temp_cropped_liver,
                                 _model_fullpath(settings.MODELS.LIVER_CROPPING))
    _change_the_sizes(ct_output_file, temp_cropped_liver)


def _get_attachment_filename(response):
    """Extract the ``filename=`` value from the content-disposition header."""
    return re.findall(r"filename=(.+)", response.headers['content-disposition'])[0]


def run(node: dict, input: list, output: str):
    """Crop the downloaded image and upload the result to the tasks server.

    NOTE(review): ``tempdir``, ``image_filename``, ``image_file_response``,
    ``new_files_url`` and ``task_id`` are not defined anywhere in the visible
    source; the code that sets them up appears to be missing and must be
    restored before this function can run.
    """
    try:
        image_local_filepath = os.path.join(tempdir.name, image_filename)
        with open(image_local_filepath, "wb") as image_file:
            image_file.write(image_file_response.content)

        output_image_filename = os.path.join(tempdir.name, f"cropped-{image_filename}")
        liver_roi_cropping(input_image=image_local_filepath,
                           output_image=output_image_filename)

        output_filename = "cropped.nii.gz"
        with open(output_image_filename, "rb") as cropped_image_file:
            request_files = {
                "filename": (None, output_filename),
                "file": ("file", cropped_image_file)
            }
            logging.warning(" @@@@@@@@@@@@ cropping liver done. task_id: %r", task_id)
            image_sending_response = requests.put(new_files_url, files=request_files)

        if image_sending_response.status_code != 200:
            logging.warning(" send files fails %r", image_sending_response.text)
    finally:
        # Always release the temporary directory, even when a step fails.
        tempdir.cleanup()