#!/usr/bin/env python3
import json
import os
import re
import tempfile
import requests
import logging
import pickle
import codecs
import traceback as tb
import uuid
from skimage.transform import resize
from skimage.measure import label
import gc
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.device('GPU:0')
from tensorflow.keras.models import load_model
import nibabel as nib
import numpy as np
import SimpleITK as sitk
from dynaconf import settings
from os.path import isfile
from scipy import ndimage
import tensorflow as tf
# Path of the Keras model used to segment the liver before cropping (read from the environment).
CROPPING_MODEL_PATH = os.environ.get('CROPPING_MODEL_PATH')
# Host and port of the tasks server that input files are downloaded from and
# output files are uploaded to. NOTE(review): if either variable is unset the
# URL below contains the literal text "None".
TASKS_SERVER_HOSTNAME = os.getenv('TASKS_SERVER_HOSTNAME')
TASKS_SERVER_PORT = os.getenv('TASKS_SERVER_PORT')
TASKS_SERVER_URL = "http://{}:{}/".format(TASKS_SERVER_HOSTNAME, TASKS_SERVER_PORT)
# Silence SimpleITK's global debug and warning output.
sitk.ProcessObject_GlobalDefaultDebugOff()
sitk.ProcessObject_GlobalWarningDisplayOff()
logging.basicConfig(level=logging.INFO)
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that loads CT cases and masks from NIfTI files.

    Each item is a batch ``(X, y)`` whose arrays have shape
    ``(batch_size, n_channels, *dim)``. Cases are reoriented to the closest
    canonical orientation, clipped to [-150, 150], resized to ``dim`` and
    min-max normalised. Masks are clipped to [0, 1], resized and re-binarised.
    Resized results are cached next to the source files (``x.nii.gz`` ->
    ``x_resized.nii.gz``) so later loads skip the resize. Caching is skipped
    when a case and its mask are the same file.
    """

    def __init__(self, niftis, masks, batch_size=10, dim=(256, 256, 50), n_channels=1,
                 n_classes=10, shuffle=True, mask_down_resize_order: int = 3):
        """Create the generator.

        Args:
            niftis: paths of the case NIfTI files.
            masks: paths of the mask NIfTI files, aligned index-by-index with
                *niftis*.
            batch_size: number of samples per batch.
            dim: spatial shape every volume is resized to.
            n_channels: channel axis size of the batch arrays. Preprocessing
                adds exactly one channel axis, so presumably this must be 1.
            n_classes: stored, but not used by this class.
            shuffle: reshuffle the sample order at every epoch end.
            mask_down_resize_order: spline order used when resizing masks.
        """
        self.niftis = niftis
        self.masks = masks
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.mask_down_resize_order = mask_down_resize_order
        self.list_IDs = list(range(len(niftis)))
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch.

        Trailing samples that do not fill a whole batch are dropped, because
        every batch is built in fixed-size ``batch_size`` arrays.
        """
        return len(self.niftis) // self.batch_size

    def __getitem__(self, index):
        """Return batch number *index* as ``(X, y)``."""
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        niftis_temp = [self.niftis[k] for k in batch_indexes]
        masks_temp = [self.masks[k] for k in batch_indexes]
        return self.__data_generation(niftis_temp, masks_temp)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``shuffle`` is enabled."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    @staticmethod
    def remove_small_components(mask, minimum_cc_sum=100):
        """Zero out connected components whose summed mask value is below *minimum_cc_sum*.

        *mask* is modified in place and also returned. Nothing is removed when
        the mask has at most one component.
        """
        labelled_mask, num_labels = ndimage.label(mask)
        if num_labels > 1:
            logging.info('removing small components')
            # ndimage.label numbers components 1..num_labels; 0 is the background.
            for component in range(1, num_labels + 1):
                component_voxels = labelled_mask == component
                if np.sum(mask[component_voxels]) < minimum_cc_sum:
                    mask[component_voxels] = 0
        return mask

    @staticmethod
    def _resized_cache_path(path):
        """Return the cache path for the resized version of *path*.

        Only the trailing extension is rewritten (``x.nii.gz`` ->
        ``x_resized.nii.gz``, ``x.nii`` -> ``x_resized.nii``). Any other path
        gets ``_resized.nii.gz`` appended, so the result never equals *path*
        and the source file can never be overwritten.
        """
        for suffix in ('.nii.gz', '.nii'):
            if path.endswith(suffix):
                return path[:-len(suffix)] + '_resized' + suffix
        return path + '_resized.nii.gz'

    @staticmethod
    def _min_max_normalize(volume):
        """Linearly rescale *volume* to [0, 1]; a constant volume becomes all zeros.

        Equals the previous ``(v + |min|) / (|max| + |min|)`` formula whenever
        ``min <= 0 <= max``. That is the usual case for CT clipped to
        [-150, 150], where air saturates at -150.
        """
        vmin = volume.min()
        span = volume.max() - vmin
        if span == 0:
            return np.zeros_like(volume)
        return (volume - vmin) / span

    def __data_generation(self, niftis_temp, masks_temp):
        """Load, preprocess and stack one batch of cases and masks.

        Returns ``(X, y)``, both of shape ``(batch_size, n_channels, *dim)``;
        ``y`` has an integer dtype.

        Raises:
            Exception: if a case file has at most two unique values, which
                suggests a mask was passed as a case.
        """
        X = np.empty((self.batch_size, self.n_channels, *self.dim))
        y = np.empty((self.batch_size, self.n_channels, *self.dim), dtype=int)
        for i, (case_path, mask_path) in enumerate(zip(niftis_temp, masks_temp)):
            logging.debug(" == > masks_temp[i] %s", mask_path)
            logging.debug(" == > ID %s", case_path)
            case_resized_path = self._resized_cache_path(case_path)
            mask_resized_path = self._resized_cache_path(mask_path)
            # When one file is passed as both case and mask (as inference does),
            # both cache paths coincide and the saved mask would overwrite the
            # cached case, so caching is disabled in that situation.
            use_cache = case_resized_path != mask_resized_path
            if use_cache and isfile(case_resized_path) and isfile(mask_resized_path):
                X[i, ] = nib.load(case_resized_path).get_fdata()
                y[i, ] = nib.load(mask_resized_path).get_fdata()
                continue

            nifti_file = nib.as_closest_canonical(nib.load(case_path))
            case = nifti_file.get_fdata()
            if np.unique(case).size <= 2:
                raise Exception(f'It looks like the case file is in fact a mask file. It has les or equal to 2 unique variables: {case_path}')
            case = np.clip(case, -150, 150)
            case = np.expand_dims(resize(case, self.dim), axis=0)
            case = self._min_max_normalize(case)
            X[i, ] = case

            mask_file = nib.as_closest_canonical(nib.load(mask_path))
            mask_case = np.clip(mask_file.get_fdata(), 0, 1)
            mask_case = np.expand_dims(
                resize(mask_case, self.dim, anti_aliasing=False, order=self.mask_down_resize_order),
                axis=0)
            # Interpolation produces fractional values; threshold back to {0, 1}.
            mask_case = (mask_case > 0.5).astype(mask_case.dtype)
            y[i, ] = mask_case

            if use_cache:
                # The cached case is stored with the pre-resize affine; it is a
                # run-time cache only, not a spatially accurate image.
                nib.save(nib.Nifti1Image(case, nifti_file.affine), case_resized_path)
                nib.save(nib.Nifti1Image(mask_case, mask_file.affine), mask_resized_path)
        return X, y
def liver_segmentation_new_cases(CT_Path, ct_cropped, tmp_liver_cropped, model_name):
    """Segment the liver in a CT volume; save the canonical CT and the liver mask.

    Args:
        CT_Path: path of the input CT NIfTI file.
        ct_cropped: output path for the CT reoriented to the closest canonical
            orientation.
        tmp_liver_cropped: output path for the binary liver mask. It is written
            on the same voxel grid and with the same affine as the canonical CT.
        model_name: path of the Keras model passed to ``load_model``.

    Raises:
        ValueError: if the thresholded prediction contains no foreground voxel.
    """
    def getLargestCC(segmentation):
        """Return a boolean mask of the largest connected component of *segmentation*."""
        labels = label(segmentation)
        if labels.max() == 0:
            raise ValueError(f'No liver voxels were predicted for {CT_Path}')
        # bincount[0] counts the background, so skip it and shift the argmax back by one.
        return labels == np.argmax(np.bincount(labels.flat)[1:]) + 1

    params = {'dim': (128, 128, 48),
              'batch_size': 1,
              'n_classes': 6,
              'n_channels': 1,
              'shuffle': True}
    model = load_model(model_name, compile=False)
    # The CT is passed as its own "mask": only the preprocessed case (first
    # element of the batch) is used for prediction.
    test_generator = DataGenerator([CT_Path], [CT_Path], **params, mask_down_resize_order=3)
    batch, _ = test_generator.__getitem__(0)
    predictions = model.predict(batch)
    gc.collect()
    # First sample, first channel -> a volume of shape params['dim'].
    predictions = predictions[0][0]

    nifti_file = nib.load(CT_Path)
    logging.warning("shape before as_closest_canonical: %r", nifti_file.shape)
    nifti_file = nib.as_closest_canonical(nifti_file)
    logging.warning("shape after as_closest_canonical: %r", nifti_file.shape)
    nib.save(nifti_file, ct_cropped)
    case = nifti_file.get_fdata()

    predictions = (predictions >= 0.5).astype(predictions.dtype)
    largest = getLargestCC(predictions)
    # skimage >= 0.19 refuses to interpolate bool images with order > 0, so cast
    # to float first (older versions performed the same conversion implicitly).
    predictions = resize(largest.astype(np.float64), case.shape, anti_aliasing=False, order=1)
    predictions = (predictions > 0.5).astype(predictions.dtype)
    nib.save(nib.Nifti1Image(predictions, nifti_file.affine), tmp_liver_cropped)
def _change_the_sizes(ct_file, temp_liver_cropped):
    """Crop *ct_file* in place along z to the slices containing liver voxels.

    A margin of at least 1.5 mm (rounded up to whole slices) is kept on both
    sides. The affine is shifted so the cropped volume keeps its world
    position. The temporary mask file is deleted afterwards, even on error.

    Assumes *ct_file* and *temp_liver_cropped* share the same voxel grid, as
    written by ``liver_segmentation_new_cases`` -- confirm for other callers.

    Raises:
        ValueError: if the liver mask contains no foreground voxel.
    """
    try:
        liver_file = nib.load(temp_liver_cropped)
        liver_case = liver_file.get_fdata()
        z_mm = liver_file.header.get_zooms()[2]
        nifti_file = nib.load(ct_file)
        nifti_case = nifti_file.get_fdata()
        logging.warning("in _change_the_sizes, nifti_case shape before cropping: %r", nifti_case.shape)

        liver_slices = np.flatnonzero(np.any(liver_case, axis=(0, 1)))
        if liver_slices.size == 0:
            raise ValueError(f'The liver mask {temp_liver_cropped} is empty; cannot crop {ct_file}')
        margin = int(np.ceil(1.5 / z_mm))
        zmin = max(0, int(liver_slices[0]) - margin)
        # +1 because slice ends are exclusive: keeps `margin` slices above the
        # last liver slice, symmetric with the lower margin.
        zmax = min(nifti_case.shape[2], int(liver_slices[-1]) + margin + 1)
        nifti_case = nifti_case[:, :, zmin:zmax]
        logging.warning("in _change_the_sizes, nifti_case shape after cropping: %r", nifti_case.shape)

        # Voxel (i, j, 0) of the crop is voxel (i, j, zmin) of the original, so
        # move the affine's origin by zmin steps along the third voxel axis.
        affine = nifti_file.affine.copy()
        affine[:3, 3] += affine[:3, 2] * zmin
        nib.save(nib.Nifti1Image(nifti_case, affine), ct_file)
    finally:
        if os.path.exists(temp_liver_cropped):
            os.remove(temp_liver_cropped)
def liver_roi_cropping(**kwargs):
    """Crop a CT volume to the z-range occupied by the liver.

    Keyword Args:
        input_image: path of the CT NIfTI file to crop.
        output_image: path the cropped CT is written to.

    An intermediate liver mask named ``temp-liver-cropped-<input name>`` is
    written next to the input file; ``_change_the_sizes`` removes it.
    """
    source_path = kwargs.get('input_image')
    target_path = kwargs.get('output_image')
    source_dir = os.path.dirname(source_path)
    mask_name = f"temp-liver-cropped-{os.path.basename(source_path)}"
    mask_path = os.path.join(source_dir, mask_name)
    liver_segmentation_new_cases(source_path, target_path, mask_path, CROPPING_MODEL_PATH)
    _change_the_sizes(target_path, mask_path)
def _get_attachment_filename(response):
return re.findall("filename=(.+)", response.headers['content-disposition'])[0]
def run(node: dict, inputs: dict, output: str):
    """Download a CT from the tasks server, crop it to the liver, and upload the result.

    Args:
        node: task node description; not used by this step.
        inputs: mapping from input file names to server paths. It must contain
            the key ``"input.nii.gz"``.
        output: server path, relative to ``TASKS_SERVER_URL``, that the cropped
            volume is POSTed to as multipart field ``"file"``.

    Raises:
        requests.HTTPError: if the download or the upload returns an error status.
        requests.Timeout: if the server stops responding.
    """
    # (connect, read) timeouts in seconds. The read timeout bounds the wait
    # between received bytes, not the whole transfer, so large volumes still work.
    http_timeout = (10, 600)
    # Used as a context manager, TemporaryDirectory yields the directory path
    # (a str), not the TemporaryDirectory object.
    with tempfile.TemporaryDirectory() as tempdir:
        input_file_name = "input.nii.gz"
        output_filename = "cropped.nii.gz"
        input_local_filepath = os.path.join(tempdir, input_file_name)
        output_local_filepath = os.path.join(tempdir, output_filename)
        input_url = TASKS_SERVER_URL + inputs[input_file_name]
        output_url = TASKS_SERVER_URL + output

        input_file_response = requests.get(url=input_url, timeout=http_timeout)
        input_file_response.raise_for_status()
        with open(input_local_filepath, "wb") as input_file:
            input_file.write(input_file_response.content)

        liver_roi_cropping(input_image=input_local_filepath, output_image=output_local_filepath)

        with open(output_local_filepath, "rb") as output_file:
            output_file_response = requests.post(
                url=output_url,
                files={"file": (output_filename, output_file)},
                timeout=http_timeout,
            )
        output_file_response.raise_for_status()
if __name__ == '__main__':
    # `os` has no `argv`; command-line arguments live in `sys.argv`.
    import sys
    # The task description arrives as one JSON-encoded argument whose keys are
    # run()'s parameters (node, inputs, output).
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} '<json task description>'")
    data = json.loads(sys.argv[1])
    run(**data)