Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
10 kB
2
Indexable
Never
#!/usr/bin/env python3
import json
import os
import re
import tempfile
import requests
import logging
import pickle
import codecs
import traceback as tb
import uuid
from skimage.transform import resize
from skimage.measure import label
import gc
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.device('GPU:0')
from tensorflow.keras.models import load_model
import nibabel as nib
import numpy as np
import SimpleITK as sitk
from dynaconf import settings
from os.path import isfile
from scipy import ndimage
import tensorflow as tf


# Path of the Keras model file used to segment the liver before cropping.
CROPPING_MODEL_PATH = os.environ.get('CROPPING_MODEL_PATH')

# Base URL of the tasks server that serves input volumes and receives results.
# NOTE(review): an unset variable is rendered as "None" inside the URL;
# confirm the deployment always provides both values.
TASKS_SERVER_HOSTNAME = os.environ.get('TASKS_SERVER_HOSTNAME')
TASKS_SERVER_PORT = os.environ.get('TASKS_SERVER_PORT')
TASKS_SERVER_URL = "http://{}:{}/".format(TASKS_SERVER_HOSTNAME, TASKS_SERVER_PORT)

# Silence SimpleITK's global debug output and warning display.
sitk.ProcessObject_GlobalDefaultDebugOff()
sitk.ProcessObject_GlobalWarningDisplayOff()

logging.basicConfig(level=logging.INFO)


class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(X, y)`` batches of resized NIfTI volumes and masks.

    Every case is reoriented with ``nib.as_closest_canonical``, clipped to
    [-150, 150], resized to ``dim`` and rescaled. Every mask is reoriented the
    same way, clipped to [0, 1], resized with spline order
    ``mask_down_resize_order`` and binarised at 0.5. Resized results are
    cached next to the source files as ``*_resized.nii.gz`` so later epochs
    can skip the resize.

    Args:
        niftis: Paths of the input volumes.
        masks: Paths of the matching masks (same length and order as ``niftis``).
        batch_size: Samples per batch; the final batch may be smaller.
        dim: Spatial shape every sample is resized to.
        n_channels: Size of axis 1 of ``X`` and ``y``.
        n_classes: Stored for API compatibility; not used by this class.
        shuffle: Whether to shuffle the sample order at the end of each epoch.
        mask_down_resize_order: Spline order used when resizing masks.
    """

    def __init__(self, niftis, masks, batch_size=10, dim=(256, 256, 50), n_channels=1,
                 n_classes=10, shuffle=True, mask_down_resize_order: int = 3):
        super().__init__()
        self.niftis = niftis
        self.batch_size = batch_size
        self.masks = masks
        self.list_IDs = list(range(len(niftis)))
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.dim = dim
        self.mask_down_resize_order = mask_down_resize_order
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch, counting a final partial batch."""
        # Integer ceil-division: every sample is served exactly once per epoch.
        return (len(self.niftis) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X, y)``."""
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        niftis_temp = [self.niftis[k] for k in batch_indexes]
        masks_temp = [self.masks[k] for k in batch_indexes]
        return self.__data_generation(niftis_temp, masks_temp)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    @staticmethod
    def remove_small_components(mask, minimum_cc_sum=100):
        """Zero out connected components of ``mask`` whose voxel sum is below ``minimum_cc_sum``.

        Components are found with ``scipy.ndimage.label`` (default structuring
        element). ``mask`` is modified in place and also returned.
        """
        labelled_mask, num_labels = ndimage.label(mask)
        if num_labels > 1:
            print('removing small components')
        # ndimage.label numbers components 1..num_labels; 0 is background.
        for component in range(1, num_labels + 1):
            component_voxels = labelled_mask == component
            if np.sum(mask[component_voxels]) < minimum_cc_sum:
                mask[component_voxels] = 0
        return mask

    def __data_generation(self, niftis_temp, masks_temp):
        """Load, preprocess and stack the samples of one batch.

        Returns:
            ``(X, y)``, each shaped ``(len(niftis_temp), n_channels, *dim)``;
            ``X`` is float64 and ``y`` is int.

        Raises:
            Exception: If a case volume has two or fewer distinct values
                (it is probably a mask passed as a case).
        """
        # Size by the actual number of samples so a final partial batch is not
        # padded with uninitialised np.empty rows.
        n_samples = len(niftis_temp)
        X = np.empty((n_samples, self.n_channels, *self.dim))
        y = np.empty((n_samples, self.n_channels, *self.dim), dtype=int)

        for i, (case_path, mask_path) in enumerate(zip(niftis_temp, masks_temp)):
            print(" == > masks_temp[i] ", mask_path)
            print(" == > ID ", case_path)

            case_resized_path = case_path.replace(".nii.gz", "_resized.nii.gz")
            mask_resized_path = mask_path.replace(".nii.gz", "_resized.nii.gz")
            # Use the resize cache only when its files are distinct from the
            # inputs and from each other. A path without ".nii.gz" maps onto
            # itself, so the raw input would be read back as "resized" and later
            # overwritten. When a case doubles as its own mask (as at inference),
            # the mask cache would overwrite the case cache.
            use_cache = (case_resized_path != case_path
                         and mask_resized_path != mask_path
                         and case_resized_path != mask_resized_path)
            if use_cache and isfile(case_resized_path) and isfile(mask_resized_path):
                X[i, ] = nib.load(case_resized_path).get_fdata()
                y[i, ] = nib.load(mask_resized_path).get_fdata()
                continue

            nifti_file = nib.as_closest_canonical(nib.load(case_path))
            case = nifti_file.get_fdata()

            if np.unique(case).size <= 2:
                raise Exception(f'It looks like the case file is in fact a mask file. It has les or equal to 2 unique variables: {case_path}')

            case = np.clip(case, -150, 150)
            case = np.expand_dims(resize(case, self.dim), axis=0)
            # NOTE: equals min-max scaling to [0, 1] only when min <= 0 <= max.
            case = (case + abs(case.min())) / (abs(case.max()) + abs(case.min()))
            X[i, ] = case

            mask_file = nib.as_closest_canonical(nib.load(mask_path))
            mask_case = np.clip(mask_file.get_fdata(), 0, 1)
            mask_case = np.expand_dims(
                resize(mask_case, self.dim, anti_aliasing=False, order=self.mask_down_resize_order), axis=0)
            # Spline interpolation yields fractional values; re-binarise.
            mask_case = (mask_case > 0.5).astype(mask_case.dtype)
            y[i, ] = mask_case

            if use_cache:
                # NOTE(review): the cached images reuse the source affine although
                # their data are resized and carry a leading channel axis, so they
                # are only suitable for being read back by this generator.
                nib.save(nib.Nifti1Image(case, nifti_file.affine), case_resized_path)
                nib.save(nib.Nifti1Image(mask_case, mask_file.affine), mask_resized_path)

        return X, y


def liver_segmentation_new_cases(CT_Path, ct_cropped, tmp_liver_cropped, model_name):
    """Segment the liver in one CT volume; save the reoriented CT and a binary liver mask.

    Args:
        CT_Path: Path of the input CT NIfTI file.
        ct_cropped: Output path for the CT after ``nib.as_closest_canonical``.
            This function does not crop it; only the orientation changes here.
        tmp_liver_cropped: Output path for the binary liver mask. The mask is
            the largest predicted connected component, resized back to the
            reoriented CT's shape and saved with its affine.
        model_name: Path of the Keras model file passed to ``load_model``.

    Raises:
        ValueError: If the thresholded prediction contains no foreground voxel.
    """
    def getLargestCC(segmentation):
        """Return a boolean mask of the largest connected component of ``segmentation``."""
        labels = label(segmentation)
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if labels.max() == 0:
            raise ValueError('The model predicted no liver voxels (no connected component found).')
        # bincount()[1:] skips background label 0; +1 maps the index back to a label id.
        return labels == np.argmax(np.bincount(labels.flat)[1:]) + 1

    params = {'dim': (128, 128, 48),
              'batch_size': 1,
              'n_classes': 6,
              'n_channels': 1,
              'shuffle': True}

    # mdl_after_richard_fix_with_new_augs.h5
    # old network - mdl_after_richard_fix.h5
    model = load_model(model_name, compile=False)

    # The CT is also passed as the "mask" only because the generator requires
    # one; the resulting y is discarded.
    test_generator = DataGenerator([CT_Path], [CT_Path], **params, mask_down_resize_order=3)
    a, _ = test_generator[0]

    predictions = model.predict(a)
    gc.collect()
    # First sample, first channel. Presumably a probability volume shaped like
    # params['dim']; TODO confirm against the model's output layout.
    predictions = predictions[0][0]

    nifti_file = nib.load(CT_Path)
    logging.warning("shape before as_closest_canonical: %r", nifti_file.shape)
    nifti_file = nib.as_closest_canonical(nifti_file)
    logging.warning("shape after as_closest_canonical: %r", nifti_file.shape)

    nib.save(nifti_file, ct_cropped)

    liver_mask = getLargestCC(predictions >= 0.5)

    # scikit-image rejects bool input with order > 0 (ValueError in recent
    # versions), so cast to float before the linear upsampling. The image's
    # shape is used directly instead of loading the whole volume just for it.
    upsampled = resize(liver_mask.astype(np.float64), nifti_file.shape, anti_aliasing=False, order=1)
    upsampled = (upsampled > 0.5).astype(upsampled.dtype)

    nib.save(nib.Nifti1Image(upsampled, nifti_file.affine), tmp_liver_cropped)




def _change_the_sizes(ct_file, temp_liver_cropped):
    """Crop ``ct_file`` in place along z to the liver's extent plus a margin of at least 1.5 mm.

    The liver mask in ``temp_liver_cropped`` is assumed to share ``ct_file``'s
    voxel grid, because its z range is applied directly to the CT array. The
    margin is ``ceil(1.5 / z_mm)`` slices on each side, clamped to the volume.
    The cropped CT is written back to ``ct_file`` as float data with the affine
    shifted to the new first slice. The mask file is then deleted.

    Raises:
        ValueError: If the liver mask contains no foreground voxel.
    """
    liver_file = nib.load(temp_liver_cropped)
    liver_case = liver_file.get_fdata()
    z_mm = liver_file.header.get_zooms()[2]

    nifti_file = nib.load(ct_file)
    nifti_case = nifti_file.get_fdata()
    logging.warning("in _change_the_sizes, nifti_case shape before cropping: %r", nifti_case.shape)

    liver_slices = np.flatnonzero(np.any(liver_case, axis=(0, 1)))
    if liver_slices.size == 0:
        raise ValueError(f'Liver mask {temp_liver_cropped} is empty; cannot determine the crop range.')

    margin = int(np.ceil(1.5 / z_mm))
    zmin = max(0, int(liver_slices[0]) - margin)
    # liver_slices[-1] is an inclusive index; +1 makes it an exclusive slice end
    # so the margin above the liver matches the margin below it.
    zmax = min(nifti_case.shape[2], int(liver_slices[-1]) + margin + 1)
    nifti_case = nifti_case[:, :, zmin:zmax]
    logging.warning("in _change_the_sizes, nifti_case shape after cropping: %r", nifti_case.shape)

    # Dropping the first zmin slices moves voxel (0, 0, 0); shift the affine's
    # translation so every remaining voxel keeps its world coordinate.
    affine = nifti_file.affine.copy()
    affine[:3, 3] = affine[:3, :3] @ np.array([0.0, 0.0, zmin]) + affine[:3, 3]

    nib.save(nib.Nifti1Image(nifti_case, affine), ct_file)

    os.remove(temp_liver_cropped)


def liver_roi_cropping(**kwargs):
    """Write a copy of ``input_image``, cropped to the liver region, to ``output_image``.

    Keyword Args:
        input_image: Path of the input CT volume.
        output_image: Path where the cropped CT is written.
    """
    input_path = kwargs.get('input_image')
    output_path = kwargs.get('output_image')

    # The intermediate liver mask lives next to the input and is removed
    # again by _change_the_sizes.
    mask_dir = os.path.dirname(input_path)
    mask_name = "temp-liver-cropped-" + os.path.basename(input_path)
    mask_path = os.path.join(mask_dir, mask_name)

    liver_segmentation_new_cases(input_path, output_path, mask_path, CROPPING_MODEL_PATH)
    _change_the_sizes(output_path, mask_path)


def _get_attachment_filename(response):
    """Return the filename from a response's ``Content-Disposition`` header.

    Both quoted (``filename="a b.nii.gz"``) and bare (``filename=a.nii.gz``)
    values are supported. Surrounding quotes and any following
    ``; param=...`` parts are excluded from the result.

    Raises:
        KeyError: From ``response.headers`` if the header is missing.
        ValueError: If the header carries no ``filename=`` parameter.
    """
    disposition = response.headers['content-disposition']
    match = re.search(r'filename="([^"]*)"|filename=([^;]+)', disposition)
    if match is None:
        raise ValueError(f'No filename in Content-Disposition header: {disposition!r}')
    quoted, bare = match.groups()
    return quoted if quoted is not None else bare.strip()


def run(node: dict, inputs: dict, output: str):
    """Download a CT from the tasks server, crop it to the liver ROI and upload the result.

    Args:
        node: Task node description; not used by this function.
        inputs: Mapping whose ``"input.nii.gz"`` entry is the server-relative
            path of the input volume.
        output: Server-relative path the cropped volume is POSTed to.

    Raises:
        requests.HTTPError: If the download or the upload returns an error status.
    """
    input_file_name = "input.nii.gz"
    output_filename = "cropped.nii.gz"

    # Used as a context manager, TemporaryDirectory() yields the directory path
    # as a plain str (it has no ``.name``), and the directory is removed on exit.
    with tempfile.TemporaryDirectory() as tempdir:
        input_local_filepath = os.path.join(tempdir, input_file_name)
        output_local_filepath = os.path.join(tempdir, output_filename)

        input_url = TASKS_SERVER_URL + inputs[input_file_name]
        output_url = TASKS_SERVER_URL + output

        input_file_response = requests.get(url=input_url)
        input_file_response.raise_for_status()

        with open(input_local_filepath, "wb") as input_file:
            input_file.write(input_file_response.content)

        liver_roi_cropping(input_image=input_local_filepath, output_image=output_local_filepath)

        with open(output_local_filepath, "rb") as output_file:
            output_file_response = requests.post(
                url=output_url, files={"file": (output_filename, output_file)})
        output_file_response.raise_for_status()


if __name__ == '__main__':
    # The task arrives as one JSON object holding run()'s keyword arguments
    # (node, inputs, output). Command-line arguments live in sys.argv; the os
    # module has no argv attribute.
    import sys

    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} \'{{"node": ..., "inputs": ..., "output": ...}}\'')
    data = json.loads(sys.argv[1])
    run(**data)