# Untitled
#
# Paste-site metadata captured with this code (pastecode.io):
# mail@pastecode.io avatar | unknown | plain_text | a year ago | 66 kB | 3 | Indexable
import os
import re
import tempfile
import requests
import SimpleITK as sitk
import nibabel as nib
import numpy as np
from dynaconf import settings
from time import time, gmtime
from typing import List, Tuple
from itertools import product
from imgaug.augmenters import Identity, Sequential, Affine

# from keras.models import load_model
from nibabel import load, as_closest_canonical
from os.path import basename, dirname
from skimage.transform import resize
from skimage.measure import label
from functools import partial
import warnings
import gc
from scipy.ndimage import binary_fill_holes
import tensorflow.compat.v1 as tf

# Switch the TF1-compat API (imported above as `tf`) into graph mode; this has to run before the
# keras `load_model` import that follows.
tf.disable_v2_behavior()
# NOTE(review): tf.device() returns a context manager; called bare like this (outside a `with`
# block) it most likely has no effect on device placement - confirm the intent.
tf.device("GPU:0")
from tensorflow.keras.models import load_model
from os import system
from itertools import product
from typing import Iterable, Union, Generator
from imgaug.augmenters import Augmenter, Identity, UnnormalizedBatch, Sequential, Affine
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from imgaug.parameters import Choice

# from utils import display_gray_2D_image
from copy import deepcopy
from nibabel import as_closest_canonical
from multiprocessing import cpu_count

# Value of the MODELS_FOLDER environment variable (None when it is not set).
_MODELS_FOLDER = os.environ.get("MODELS_FOLDER")

# Turn off SimpleITK's global debug output and its warning display.
sitk.ProcessObject_GlobalDefaultDebugOff()
sitk.ProcessObject_GlobalWarningDisplayOff()

# Address of the tasks server, built from environment variables.
# NOTE: if either variable is unset, the literal text "None" ends up in the URL.
TASKS_SERVER_HOSTNAME = os.environ.get("TASKS_SERVER_HOSTNAME")
TASKS_SERVER_PORT = os.environ.get("TASKS_SERVER_PORT")
TASKS_SERVER_URL = f"http://{TASKS_SERVER_HOSTNAME}:{TASKS_SERVER_PORT}/"


def augment_target(augment: Augmenter, im: np.ndarray):
    """
    Run a single augmenter on a single image and return the augmented image.

    A module-level function (rather than a lambda) - presumably so that it can be shipped to a
    process pool; TODO confirm.

    :param augment: The augmenter to apply (called as ``augment(image=im)``).
    :param im: The image to augment.
    :return: Whatever ``augment(image=im)`` returns.
    """
    augmented_image = augment(image=im)
    return augmented_image


class Augmenters(list):
    """
    A container of several augmenters.

    .. Note::
        Some of the methods of this class (especially the methods that are being applied on 2 images and 2 segmentations)
        supports only augmenters that change the image geometry (like flips, affine transformation etc.).
    """

    def __init__(self, seq=(), n=None, seed=None):
        """
        Building the container from the given augmenters.

        :param seq: An iterable of augmenters to store in the container.
        :param n: [Optional] When given and smaller than the number of augmenters in 'seq', only 'n' of them are
            kept, picked at random without repetition. Otherwise all of them are kept.
        :param seed: The seed used for the random pick (None means unseeded randomness).
        """
        super().__init__(seq)

        # nothing to pick when no limit is requested or the limit covers everything
        if n is None or n >= len(self):
            return

        picker = np.random.RandomState(seed=seed)
        chosen = picker.choice(len(self), n, replace=False)

        # re-initialize the list in place with only the chosen augmenters
        self.__init__([self[idx] for idx in chosen])

    def apply_all_on_image(
        self, image: np.ndarray, original_also: bool = False
    ) -> Generator[Tuple[np.ndarray, List[Tuple[str, dict]]], None, None]:
        """
        Lazily running every augmenter of the container on the given image.

        :param image: An image (ndarray).
        :param original_also: When True, the first yielded item is a copy of the un-augmented image, paired with the
            information of an 'Identity' augmenter, followed by the augmented images.

        :return: A generator yielding, for each augmenter aug_i in container order, a tuple
            (img_aug_i, aug_i_info_list), where img_aug_i is aug_i(image=image) and aug_i_info_list is the result of
            'get_deterministic_augs_info(aug_i)'.

        :raises Exception: When the container is empty (raised once the generator starts being consumed).
        """

        self.__raise_exception_if_empty()

        if original_also:
            yield image.copy(), self.get_deterministic_augs_info(Identity())

        for augmenter in self:
            augmented_image = augmenter(image=image)
            yield augmented_image, self.get_deterministic_augs_info(augmenter)

    def apply_all_on_image_multiprocessed(
        self, image: np.ndarray, original_also: bool = False
    ) -> List[Tuple[np.ndarray, List[Tuple[str, dict]]]]:
        """
        Applying all the augmenters in the container on the given image and returning the results as a list.

        .. note::
            Despite its name, this method currently runs the augmenters sequentially in the calling process; the
            process pool it was meant to use is disabled. All the augmentations are computed first, and only then is
            the information of every augmenter collected.

        :param image: An image (ndarray).
        :param original_also: If set to true, the first instance in the returned result list will be a copy of the
            original given image (without any augmentation), followed by all the augmented images. Otherwise the
            returned result list will contain only the augmented images.

        :return: A list that contains, for every augmentation aug_i in the current collection of augmenters, a tuple
            in the following form: (img_aug_i, aug_i_info_list), where:
            • img_aug_i is the augmented image that was created by aug_i(image), and
            • aug_i_info_list is an ordered list containing the information of the augmentations that are included in
              aug_i (see the returned value of the current class's method 'get_deterministic_augs_info').

        :raises Exception: If the container is empty.
        """

        self.__raise_exception_if_empty()

        # 'augment_target' is the module-level worker; it is called directly here (no pool).
        augmented_images = [augment_target(aug, image) for aug in self]
        result = [
            (augmented_image, self.get_deterministic_augs_info(aug))
            for augmented_image, aug in zip(augmented_images, self)
        ]

        if original_also:
            result = [
                (image.copy(), self.get_deterministic_augs_info(Identity()))
            ] + result

        return result

    def apply_all_on_image_and_segment(
        self, image: np.ndarray, seg: np.ndarray, original_also: bool = False
    ) -> List[Tuple[np.ndarray, np.ndarray, List[Tuple[str, dict]]]]:
        """
        Running every augmenter of the container on an image together with its segmentation.

        :param image: An image (ndarray).
        :param seg: A segmentation (ndarray). An int32 copy of it is wrapped in a SegmentationMapsOnImage whose shape
            is taken from 'image'.
        :param original_also: When True, the un-augmented (image, seg) pair (as copies), paired with the information
            of an 'Identity' augmenter, is placed in front of the augmented results.

        :return: A list with one tuple (img_aug_i, seg_aug_i, aug_i_info_list) per augmenter aug_i, in container
            order (see 'apply_ith_aug_on_image_and_segment' and 'get_deterministic_augs_info').
        """

        # one segmentation-map wrapper is built once and shared by all the augmenters
        shared_seg_map = SegmentationMapsOnImage(
            arr=seg.astype(np.int32), shape=image.shape
        )

        augmented = [
            self.apply_ith_aug_on_image_and_segment(image, shared_seg_map, idx)
            for idx in range(len(self))
        ]

        if not original_also:
            return augmented

        untouched = (image.copy(), seg.copy(), self.get_deterministic_augs_info(Identity()))
        return [untouched, *augmented]

    def apply_ith_aug_on_image_and_segment(
        self, image: np.ndarray, seg: Union[np.ndarray, SegmentationMapsOnImage], i: int
    ) -> Tuple[np.ndarray, np.ndarray, List[Tuple[str, dict]]]:
        """
        Running the augmenter stored at index 'i' on an image and its segmentation.

        :param image: An image (ndarray).
        :param seg: Either a segmentation ndarray (converted to int32 and wrapped in a SegmentationMapsOnImage with
            the shape of 'image') or an already-built imgaug SegmentationMapsOnImage object.
        :param i: The index of the augmenter to use, in the half closed range [0, len(self)).

        .. note::
            image and segment have to be of the same shape.

        :return: A tuple (img_aug, seg_aug, aug_info_list): the augmented image, the augmented segmentation as an
            ndarray (via 'get_arr()'), and 'get_deterministic_augs_info' of the used augmenter.

        :raises Exception: If the container is empty.
        """

        self.__raise_exception_if_empty()

        chosen_aug = self[i]

        if isinstance(seg, np.ndarray):
            seg = SegmentationMapsOnImage(arr=seg.astype(np.int32), shape=image.shape)

        augmented_image, augmented_map = chosen_aug(image=image, segmentation_maps=seg)
        augmented_seg = augmented_map.get_arr()
        return augmented_image, augmented_seg, self.get_deterministic_augs_info(chosen_aug)

    def apply_random_aug_on_image_and_segment(
        self, image: np.ndarray, seg: np.ndarray, seed: int = None
    ) -> Tuple[np.ndarray, np.ndarray, List[Tuple[str, dict]]]:
        """
        Running one randomly chosen augmenter of the container on an image and its segmentation.

        :param image: An image (ndarray).
        :param seg: A segmentation (ndarray). It will be cast to int32.
        :param seed: Seed for picking the augmenter (None means unseeded randomness).

        .. note::
            image and segment have to be of the same shape.

        :return: A tuple (img_aug, seg_aug, aug_info_list), as returned by 'apply_ith_aug_on_image_and_segment' for
            the randomly picked index.
        """

        rng = np.random.RandomState(seed=seed)
        picked_index = rng.choice(np.arange(len(self)))
        return self.apply_ith_aug_on_image_and_segment(image, seg, picked_index)

    def apply_random_aug_on_2images_and_2segments(
        self,
        img1: np.ndarray,
        img2: np.ndarray,
        seg1: np.ndarray,
        seg2: np.ndarray,
        seed: int = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[Tuple[str, dict]]]:
        """
        Running one randomly chosen augmenter identically on two images and their two segmentations.

        The two images are concatenated along axis 2, and so are the two segmentations. A single augmenter call
        therefore transforms all of them the same way, and the results are split back along the last axis.

        :param img1: An image (ndarray, dtype=float) with content in range [0,1] and relevant resolution not above 500.
        :param img2: An image (ndarray, dtype=float) with content in range [0,1] and relevant resolution not above 500.
        :param seg1: A segmentation (ndarray) with content in {0,1} with dtype bool or int. It will be cast to int32.
        :param seg2: A segmentation (ndarray) with content in {0,1} with dtype bool or int. It will be cast to int32.
        :param seed: Seed for picking the augmenter (None means unseeded randomness).

        .. note::
            All the given image/segment arguments (img1, img2, seg1 and seg2) have to be of the same shape.

            In addition, the augmentations in the container must not affect the z axis (no flips on the z axis
            etc.), because the inputs are stacked along it.

        :return: A tuple (img1_aug, img2_aug, seg1_aug, seg2_aug, aug_info_list), where aug_info_list is
            'get_deterministic_augs_info' of the randomly picked augmenter.
        """

        stacked_images = np.concatenate([img1, img2], axis=2)
        stacked_segs = np.concatenate([seg1, seg2], axis=2)

        aug_images, aug_segs, info = self.apply_random_aug_on_image_and_segment(
            stacked_images, stacked_segs, seed=seed
        )

        image_cut = img1.shape[2]
        seg_cut = seg1.shape[2]
        return (
            aug_images[..., :image_cut],
            aug_images[..., image_cut:],
            aug_segs[..., :seg_cut],
            aug_segs[..., seg_cut:],
            info,
        )

    def apply_n_random_augs_on_image(
        self,
        image: np.ndarray,
        n: int = None,
        seed: int = None,
        original_also: bool = False,
    ) -> List[Tuple[np.ndarray, List[Tuple[str, dict]]]]:
        """
        Applying n random augmenters from the container on the given image.

        :param image: An image (ndarray).
        :param n: The number of augmentations to choose from the container in order to apply them on the given image.
            If n is either set to None or is bigger than the number of augmentations available, all the augmentations
            will be applied on the input image.
        :param seed: The seed for the randomness (if seed is set to None (by default), the randomness will be truly
            "random").
        :param original_also: If set to true, the first instance in the returned result list will be a copy of the
            original given image (without any augmentation), followed by the augmented images. Otherwise the returned
            result list will contain only the augmented images.

        :return: A list (in every case, including when all the augmentations are applied) that contains, for every
            chosen augmentation aug_i, a tuple in the following form: (img_aug_i, aug_i_info_list), where:
            • img_aug_i is the augmented image that was created by aug_i(image), and
            • aug_i_info_list is an ordered list containing the information of the augmentations that are included in
              aug_i (see the returned value of the current class's method 'get_deterministic_augs_info').

        :raises Exception: If the container is empty.
        """

        num_of_augs_available = len(self)

        if n is None or n > num_of_augs_available:
            # 'apply_all_on_image' is a generator; materialize it so that this method returns a list in every
            # branch, as its annotation and documentation promise.
            return list(self.apply_all_on_image(image, original_also))

        self.__raise_exception_if_empty()

        # choosing n random augmenters from the container (without repetition).
        rand = np.random.RandomState(seed=seed)
        indexes = rand.choice(np.arange(num_of_augs_available), n, replace=False)

        result = [
            (self[i](image=image), self.get_deterministic_augs_info(self[i]))
            for i in indexes
        ]

        if original_also:
            result = [
                (image.copy(), self.get_deterministic_augs_info(Identity()))
            ] + result

        return result

    def apply_n_random_augs_on_image_and_segment(
        self,
        image: np.ndarray,
        seg: np.ndarray,
        n: int = None,
        seed: int = None,
        original_also: bool = False,
    ) -> List[Tuple[np.ndarray, np.ndarray, List[Tuple[str, dict]]]]:
        """
        Running 'n' randomly chosen augmenters of the container on an image and its segmentation.

        :param image: An image (ndarray).
        :param seg: A segmentation (ndarray). It will be cast to int32.
        :param n: How many augmenters to pick (without repetition). When None or larger than the number of
            augmenters, the call is delegated to 'apply_all_on_image_and_segment'.
        :param seed: Seed for the random pick (None means unseeded randomness).
        :param original_also: When True, the un-augmented (image, seg) pair (as copies), paired with the information
            of an 'Identity' augmenter, is placed in front of the augmented results.

        :return: A list with one tuple (img_aug_i, seg_aug_i, aug_i_info_list) per picked augmenter aug_i (see
            'apply_ith_aug_on_image_and_segment').
        """

        if n is None or n > len(self):
            return self.apply_all_on_image_and_segment(
                image, seg, original_also=original_also
            )

        rng = np.random.RandomState(seed=seed)
        picked = rng.choice(np.arange(len(self)), n, replace=False)

        # one segmentation-map wrapper, shared by every picked augmenter
        shared_seg_map = SegmentationMapsOnImage(
            arr=seg.astype(np.int32), shape=image.shape
        )

        outputs = [
            self.apply_ith_aug_on_image_and_segment(image, shared_seg_map, idx)
            for idx in picked
        ]

        if not original_also:
            return outputs

        untouched = (image.copy(), seg.copy(), self.get_deterministic_augs_info(Identity()))
        return [untouched] + outputs

    def apply_all_on_array_of_images(
        self,
        array_of_images: np.ndarray,
        original_also: bool = False,
        seed: int = None,
        debug_mode: bool = False,
    ) -> Generator[Tuple[np.ndarray, List[Tuple[str, dict]]], None, None]:
        """
        Lazily running every augmenter of the container on a stack of images.

        :param array_of_images: A 4D ndarray holding the 3D images to augment.
        :param original_also: When True, the first yielded item is 'array_of_images' itself (not a copy), paired
            with the information of an 'Identity' augmenter.
        :param seed: Seed for the augmentation randomness (None means unseeded randomness).
        :param debug_mode: When False (default), each augmenter runs through an imgaug process pool; when True, it
            runs in the current process.

        :return: A generator yielding, for each augmenter aug_i, a tuple (array_of_images_aug_i, aug_i_info_list),
            where array_of_images_aug_i is what '__apply_aug_on_batches' returns for aug_i, and aug_i_info_list is
            'get_deterministic_augs_info(aug_i)'.

        :raises Exception: When the container is empty (raised once the generator starts being consumed).
        """

        self.__raise_exception_if_empty()

        if original_also:
            yield array_of_images, self.get_deterministic_augs_info(Identity())

        # the images are split into batches once; the same batches are handed to every augmenter
        image_batches = self.__create_lists_of_batches(array_of_images)

        for augmenter in self:
            augmented = self.__apply_aug_on_batches(
                augmenter, image_batches, seed=seed, debug_mode=debug_mode
            )
            yield augmented, self.get_deterministic_augs_info(augmenter)

    def get_n_random_augmenters(self, n: int, seed: int = None):
        """
        Building a new container holding 'n' augmenters picked at random (without repetition) from this one.

        :param n: The size of the new container. When it is not smaller than the current number of augmenters, the
            current container itself (not a copy) is returned.
        :param seed: Seed for the random pick (None means unseeded randomness).

        :return: An Augmenters object.
        """

        if n >= len(self):
            return self

        # The constructor performs exactly the same seeded pick: RandomState(seed).choice(len, n, replace=False).
        return Augmenters(self, n=n, seed=seed)

    def __raise_exception_if_empty(self):
        """Raise a plain Exception when the container holds no augmenters; do nothing otherwise."""
        if self:
            return
        raise Exception("There isn't any augmentation to apply.")

    @staticmethod
    def __apply_aug_on_batches(
        aug: Augmenter,
        batches: List[UnnormalizedBatch],
        debug_mode: bool = False,
        seed=None,
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Applying a certain augmentation on a list of batches of images.

        :param aug: An Augmenter object.
        :param batches: A list of batches as UnnormalizedBatch objects.
        :param debug_mode: If set to False (by default), the batches are augmented by an imgaug process pool that
            uses half of the available CPUs (at least one). Otherwise they are augmented in the current process.
        :param seed: The seed for the randomness (if seed is set to None (by default), the randomness will be truly
            "random"). In debug mode, the augmenter is re-seeded with it for the duration of the call, and its
            previous random state is put back afterwards (even if augmentation fails).

        :return: Either:
            • (augmented_images_array, augmented_segmentations_array) - if there are images and segmentations in the given list of batches.
            • augmented_images_array - if there are only images in the given list of batches.
            • augmented_segmentations_array - if there are only segmentations in the given list of batches.
            • None if there are neither images nor segmentations in the given list of batches.
            Each returned array is a 4D ndarray containing all the augmented images/segmentations, stacked, one after
            another in their order.
        """

        if debug_mode:
            previous_random_state = aug.random_state
            aug.seed_(seed)
            try:
                # imgaug documents augment_batches() as yielding its results lazily; it must be consumed here, while
                # the requested seed is active. Otherwise the augmentation would only happen after the seed was
                # restored below.
                auged_batches = list(aug.augment_batches(batches))
            finally:
                aug.seed_(previous_random_state)
        else:
            # cpu_count() // 2 is 0 on a single-core machine, which is not a valid pool size.
            n_processes = max(1, cpu_count() // 2)
            with aug.pool(processes=n_processes, seed=seed) as pool:
                auged_batches = pool.map_batches(batches)

        return Augmenters.__extract_arrays_from_augmented_batches(auged_batches)

    @staticmethod
    def __extract_arrays_from_augmented_batches(
        auged_batches: List[UnnormalizedBatch],
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Extracting augmented images and/or segmentations from the given batches and converting them to 4D ndarrays of
        all the images/segmentations stacked one after another, in their order.

        :param auged_batches: An iterable of augmented batches (objects exposing 'images_aug' and
            'segmentation_maps_aug').

        :return: Either:
            • (augmented_images_array, augmented_segmentations_array) - if there are augmented images and augmented segmentations in the given batches.
            • augmented_images_array - if there are only augmented images in the given batches.
            • augmented_segmentations_array - if there are only augmented segmentations in the given batches.
            • None if there are neither augmented images nor augmented segmentations (including when no batches are given).
            Each returned array is a 4D ndarray containing all the augmented images/segmentations, stacked, one after
            another in their order.
        """

        def extract_images(batch):
            # images_aug may already be a single ndarray, a list of per-image arrays, or None.
            images = batch.images_aug
            if images is None:
                return None
            if isinstance(images, np.ndarray):
                return images
            return np.stack(images)

        def extract_segments(batch):
            if batch.segmentation_maps_aug is None:
                return None
            return np.stack([seg_map.get_arr() for seg_map in batch.segmentation_maps_aug])

        pairs = [(extract_images(batch), extract_segments(batch)) for batch in auged_batches]

        # Without this guard, 'zip(*[])' yields nothing and the two-name unpacking below raises ValueError.
        if not pairs:
            return None

        auged_images, auged_segments = zip(*pairs)

        auged_images = [im for im in auged_images if im is not None]
        auged_images = np.concatenate(auged_images) if auged_images else None

        auged_segments = [seg for seg in auged_segments if seg is not None]
        auged_segments = np.concatenate(auged_segments) if auged_segments else None

        if auged_images is not None and auged_segments is not None:
            return auged_images, auged_segments
        if auged_images is not None:
            return auged_images
        if auged_segments is not None:
            return auged_segments
        return None

    @staticmethod
    def __create_lists_of_batches(
        array_of_images: np.ndarray = None,
        array_of_segments: np.ndarray = None,
        n_batches: int = None,
    ) -> List[UnnormalizedBatch]:
        """
        Splitting images and/or segmentations into UnnormalizedBatch objects (used to feed an imgaug pool).

        :param array_of_images: [Optional] A 4D ndarray of 3D images to augment.
        :param array_of_segments: [Optional] A 4D ndarray of 3D segmentations to augment. Each segmentation is
            wrapped in a SegmentationMapsOnImage whose shape is the segmentation's own shape.
        :param n_batches: The maximum number of batches (the arrays are split with np.array_split). Defaults to
            3 * cpu_count() + 1. Batches that end up empty are dropped; when both arrays are given, emptiness is
            judged by the image part.

        :return: A list of UnnormalizedBatch objects.

        :raises Exception: If neither array is given.
        """

        if array_of_images is None and array_of_segments is None:
            raise Exception("There wasn't given any array to create the batches")

        batches_count = n_batches if n_batches is not None else 3 * cpu_count() + 1

        def to_seg_maps(segments_chunk):
            # one segmentation map per 3D segmentation in the chunk
            return [SegmentationMapsOnImage(arr=seg, shape=seg.shape) for seg in segments_chunk]

        if array_of_segments is None:
            image_chunks = np.array_split(array_of_images, batches_count)
            return [
                UnnormalizedBatch(images=chunk)
                for chunk in image_chunks
                if chunk.shape[0] > 0
            ]

        if array_of_images is None:
            segment_chunks = np.array_split(array_of_segments, batches_count)
            return [
                UnnormalizedBatch(segmentation_maps=to_seg_maps(chunk))
                for chunk in segment_chunks
                if chunk.shape[0] > 0
            ]

        paired_chunks = zip(
            np.array_split(array_of_images, batches_count),
            np.array_split(array_of_segments, batches_count),
        )
        return [
            UnnormalizedBatch(
                images=image_chunk,
                segmentation_maps=to_seg_maps(segment_chunk),
            )
            for image_chunk, segment_chunk in paired_chunks
            if image_chunk.shape[0] > 0
        ]

    @staticmethod
    def get_deterministic_augs_info(augmenter: Augmenter) -> List[Tuple[str, dict]]:
        """
        Creating a list of all the sub-augmenters contained in a single augmenter.

        .. Note::
            • Read carefully the documentation of the current method before eny use or change.
            • Read carefully the documentation of the current class's method 'create_back_augmenter_by_info' before
                any use or change.
            • Please make sure to modify the documentation respectively to any change in the method.
            • Any change in this method must be taken into account in the current class's method
                'create_back_augmenter_by_info'.
            • Not every augmenter's parameters are taken into account. For now, the method supports only the parameters
                'scale', 'shear' and 'rotate' for 'Affine' augmenter, and the parameters 'k' and 'keep_size' for 'Rot90'
                augmenter.
            • An 'Affine' augmentation must be with at most one parameter set not to default (because it's hard to
                control the order of performing them).
            • The parameters of the augmenters most be deterministic. If an augmenter is given with either some not-
                deterministic parameter, or some sub-augmenter with some not-deterministic parameter, the augmenter (or
                the sub-augmenter, respectively) will be considered as an 'Identity' augmenter.

        :param augmenter: An imgaug.augmenters.Augmenter object.

        :return: A list of tuples. One tuple for every sub-augmenters contained in 'augmenter'.
            Each tuple is in the following form: (class_name, parameters) where:
            • 'class_name' is the class name of the the sub-augmenter (as str).
            • 'parameters' is a dictionary that contains any defined parameter with its name as a key.
                Note, not any parameter is defined to be taken in account.
        """

        def get_info(aug: Augmenter) -> Tuple[str, dict]:
            try:
                # in case the given augmentation contains a not deterministic parameter... return an Identity augmenter info
                is_not_determnistic = lambda param: isinstance(param, Choice) or (
                    isinstance(param, Iterable)
                    and not isinstance(param, str)
                    and any((is_not_determnistic(sub_param) for sub_param in param))
                )
                if is_not_determnistic(aug.get_parameters()):
                    return "Identity", dict()

                class_name = aug.__class__.__name__
                parameters = dict()

                if class_name == "Rot90":
                    k = aug.k.value
                    keep_size = aug.keep_size
                    if k == 0:
                        class_name = "Identity"
                    else:
                        parameters["k"] = k
                        parameters["keep_size"] = keep_size

                elif class_name == "Affine":

                    def raise_too_much_parameters_exception():
                        raise Exception(
                            "Affine augmentation must be with at most one parameter set not to default."
                        )

                    is_rotate: bool = aug.rotate.value != 0
                    is_scale: bool = not any(
                        [
                            isinstance(aug.scale, tuple)
                            and aug.scale[0].value == 1
                            and aug.scale[1].value == 1,
                            not isinstance(aug.scale, tuple) and aug.scale.value == 1,
                        ]
                    )
                    is_shear: bool = not any(
                        [
                            isinstance(aug.shear, tuple)
                            and aug.shear[0].value == 0
                            and aug.shear[1].value == 0,
                            not isinstance(aug.shear, tuple) and aug.shear.value == 0,
                        ]
                    )

                    # in case there isn't any affine augmentation (i.e. rotate=0 and scale=1 )
                    if not any([is_rotate, is_scale, is_shear]):
                        class_name = "Identity"

                    # in case there is a rotation parameter instance
                    elif is_rotate:
                        if any([is_scale, is_shear]):
                            raise_too_much_parameters_exception()
                        rotate = aug.rotate.value
                        parameters["rotate"] = rotate

                    # in case there is a scaling parameter instance
                    elif is_scale:
                        if is_shear:
                            raise_too_much_parameters_exception()
                        scale = (
                            dict((("x", aug.scale[0].value), ("y", aug.scale[1].value)))
                            if isinstance(aug.scale, tuple)
                            else dict((("x", aug.scale.value), ("y", aug.scale.value)))
                        )
                        parameters["scale"] = scale

                    # in case there is a shearing parameter instance
                    else:
                        shear = (
                            dict((("x", aug.shear[0].value), ("y", aug.shear[1].value)))
                            if isinstance(aug.shear, tuple)
                            else dict((("x", aug.shear.value), ("y", 0)))
                        )
                        parameters["shear"] = shear

                elif class_name in [
                    "AdditiveGaussianNoise",
                    "GaussianBlur",
                    "GammaContrast",
                ]:
                    class_name = "Identity"

                return class_name, parameters
            except:
                # print('Augmentation information extraction failed (for testing time augmentation, ignore this massage!)')
                return "Identity", dict()

        if isinstance(augmenter, Iterable):
            names = []
            for child in augmenter.get_all_children(flat=True):
                current_name = get_info(child)
                if current_name[0] not in ["Identity", "Sequential"]:
                    names.append(current_name)
            if not names:
                names = [("Identity", dict())]
        else:
            names = [get_info(augmenter)]
        return names

    @staticmethod
    def create_back_augmenter_by_info(augments: List[Tuple[str, dict]]) -> Augmenter:
        """
        Creating an augmenter (as an imgaug.augmenters.Augmenter object) that will apply on an input image (at invoking
        time) all the augmenters that their information is in the given list, in a reverse order.

        .. note::
            • Read carefully the documentation of the current method before any use or change.
            • Read carefully the documentation of the current class's method 'get_determenistic_augs_info' before any use or change.
            • Please make sure to modify the documentation respectively to any change in the method.
            • Any change in this method must be taken into account in the current class's method 'get_determenistic_augs_info'.
            • Not every augmenter's parameters are taken into account. For now, the method supports only the parameters
                'scale', 'shear' and 'rotate' for 'Affine' augmenter, and the parameters 'k' and 'keep_size' for 'Rot90'
                augmenter. Any other class name is re-created with its (unchanged) parameters.
            • An 'Affine' augmentation must be with at most one parameter set not to default (because it's hard to
                control the order of performing them).
            • The parameters of the augmenters must be deterministic.
            • Augmenter classes are looked up by name: first in this module's namespace, then in 'imgaug.augmenters'.

        :param augments: A list of tuples with augmenters information (see the returned value of the current class's
            method 'get_determenistic_augs_info').

        :raises ValueError: If a class name does not resolve to an imgaug Augmenter class.
        :raises Exception: If an 'Affine' entry has more than one parameter.

        :return: The result Augmenter object
        """
        import imgaug.augmenters as iaa

        def resolve_class(class_name: str) -> type:
            # Resolve the class by name instead of eval()-ing a generated code string: no code is executed, parameter
            # values are passed as-is (no repr round-trip), and classes this module never imported (e.g. 'Rot90')
            # are still found in imgaug.augmenters.
            for namespace in (globals(), vars(iaa)):
                candidate = namespace.get(class_name)
                if isinstance(candidate, type) and issubclass(candidate, Augmenter):
                    return candidate
            raise ValueError(f"Unknown augmenter class name: {class_name!r}")

        def create_aug(aug: Tuple[str, dict]) -> Augmenter:
            class_name, parameters = aug
            # work on a copy so the caller's info dicts are not mutated by the inversions below
            parameters = deepcopy(parameters)

            if class_name == "Rot90":
                # rotating by (4 - k) quarter turns undoes a rotation by k quarter turns
                parameters["k"] = (4 - parameters["k"]) % 4

            elif class_name == "Affine":
                if len(parameters) > 1:
                    raise Exception(
                        "Affine augmentation must be with at most one parameter."
                    )

                # in case there is a rotation parameter instance
                if "rotate" in parameters:
                    parameters["rotate"] = -parameters["rotate"]

                # in case there is a scaling parameter instance
                elif "scale" in parameters:
                    parameters["scale"]["x"] = 1 / parameters["scale"]["x"]
                    parameters["scale"]["y"] = 1 / parameters["scale"]["y"]

                # in case there is a shearing parameter instance
                # NOTE(review): negating both shear angles is an exact inverse only when at most one of them is
                # non-zero; with both set, the result is an approximation - confirm against imgaug's shear order.
                elif "shear" in parameters:
                    parameters["shear"]["x"] = -parameters["shear"]["x"]
                    parameters["shear"]["y"] = -parameters["shear"]["y"]

            return resolve_class(class_name)(**parameters)

        # the inverse of a chain of augmentations is the chain of inverses in reverse order
        return Sequential([create_aug(aug) for aug in augments[::-1]])

    @staticmethod
    def augment_back_arrays_of_images_by_info(
        array_of_images: np.ndarray,
        augments: List[Tuple[str, dict]],
        debug_mode: bool = False,
    ) -> np.ndarray:
        """
        Undo a sequence of augmentations on a stack of images.

        The augmenters described by 'augments' are inverted and applied in reverse order to every image of
        'array_of_images'. 'augments' is passed unchanged to 'create_back_augmenter_by_info'; read that method's
        documentation carefully before use.

        :param array_of_images: A 4D ndarray holding the 3D images that need to be augmented back.
        :param augments: A list of augmenter-information tuples, as returned by the current class's method
            'get_determenistic_augs_info'.
        :param debug_mode: Forwarded to the batch-application helper. Per the existing documentation, False (the
            default) processes the batches with multiprocessing and True does not.

        :return: An array of images containing the augmented-back images.
        """
        inverse_augmenter = Augmenters.create_back_augmenter_by_info(augments)

        # The stack is cut into batches so the batches can be processed in parallel.
        image_batches = Augmenters.__create_lists_of_batches(array_of_images)

        return Augmenters.__apply_aug_on_batches(
            inverse_augmenter,
            image_batches,
            debug_mode=debug_mode,
        )

    @staticmethod
    def augment_back_arrays_of_segmentations_by_info(
        array_of_segments: np.ndarray,
        augments: List[Tuple[str, dict]],
        debug_mode: bool = False,
    ) -> np.ndarray:
        """
        Undo a sequence of augmentations on a stack of segmentations.

        The augmenters described by 'augments' are inverted and applied in reverse order to every segmentation of
        'array_of_segments'. 'augments' is passed unchanged to 'create_back_augmenter_by_info'; read that method's
        documentation carefully before use.

        :param array_of_segments: A 4D ndarray holding the 3D segmentations that need to be augmented back.
        :param augments: A list of augmenter-information tuples, as returned by the current class's method
            'get_determenistic_augs_info'.
        :param debug_mode: Forwarded to the batch-application helper. Per the existing documentation, False (the
            default) processes the batches with multiprocessing and True does not.

        :return: An array of segmentations containing the augmented-back segmentations.
        """
        inverse_augmenter = Augmenters.create_back_augmenter_by_info(augments)

        # The segmentations are passed by keyword so the batching helper treats them as segmentations,
        # and are cut into batches so the batches can be processed in parallel.
        segment_batches = Augmenters.__create_lists_of_batches(
            array_of_segments=array_of_segments
        )

        return Augmenters.__apply_aug_on_batches(
            inverse_augmenter,
            segment_batches,
            debug_mode=debug_mode,
        )

    @staticmethod
    def augment_back_arrays_of_segmentations_by_informations(
        array_of_segments: np.ndarray, list_of_augments: List[List[Tuple[str, dict]]]
    ) -> np.ndarray:
        """
        Apply on each segmentation in the given array 'array_of_segments' its respective augmenters, whose
        information is in the given list 'list_of_augments', in a reverse order.

        :param array_of_segments: An array of segmentations. Namely, a 4D ndarray that contains 3D segmentations that
            need to be augmented back (array_of_segments.shape[0] must equal the length of 'list_of_augments').
        :param list_of_augments: A list that contains lists of tuples with augmenters information (see the returned
            value of the current class's method 'get_determenistic_augs_info'). Read carefully the documentation of the
            current class's method 'create_back_augmenter_by_info'. Each element of 'list_of_augments' is passed to
            that method.

        :raises ValueError: If the number of segmentations and the number of augmentation infos differ, or if the
            augmented-back segmentations do not all have the same shape.

        :return: An array of segmentations (with the dtype of 'array_of_segments') containing the augmented-back
            segmentations. An empty input yields an empty copy of 'array_of_segments'.
        """
        list_of_augments = list(list_of_augments)
        n_segments = len(array_of_segments)
        if n_segments != len(list_of_augments):
            raise ValueError(
                f"Got {n_segments} segmentations but {len(list_of_augments)} augmentation infos."
            )
        if n_segments == 0:
            return array_of_segments.copy()

        # The output is allocated once and filled in place (instead of collecting a list and stacking it), which
        # avoids holding two full copies of the results in memory at the same time.
        result = None
        for index, (seg, augments) in enumerate(zip(array_of_segments, list_of_augments)):
            seg_back = Augmenters.augment_back_segment(seg, augments)
            if result is None:
                result = np.empty(
                    (n_segments,) + seg_back.shape, dtype=array_of_segments.dtype
                )
            elif seg_back.shape != result.shape[1:]:
                raise ValueError(
                    f"Augmented-back segmentation {index} has shape {seg_back.shape}, "
                    f"expected {result.shape[1:]}."
                )
            result[index] = seg_back
        return result

    @staticmethod
    def augment_back_segment(
        seg: np.ndarray, augments: List[Tuple[str, dict]]
    ) -> np.ndarray:
        """
        Apply all the augmenters whose information is in the given list 'augments', in a reverse order, on the
        given segmentation 'seg'.

        :param seg: A ndarray representing a segmentation. It is cast to int32 before augmenting.
        :param augments: A list of tuples with augmenters information (see the returned value of the current class's
            method 'get_determenistic_augs_info'). Read carefully the documentation of the current class's method
            'create_back_augmenter_by_info'. The argument 'augments' is passed to this method.

        :return: The augmented-back segmentation, as an ndarray.
        """
        augmenter_back = Augmenters.create_back_augmenter_by_info(augments)
        # wrapping the mask in SegmentationMapsOnImage makes imgaug augment it as a segmentation map rather than as
        # an intensity image
        seg_map = SegmentationMapsOnImage(arr=seg.astype(np.int32), shape=seg.shape)
        return augmenter_back(segmentation_maps=seg_map).get_arr()

    @staticmethod
    def unpack_and_augment_back_segment(
        segment_and_augments: Tuple[np.ndarray, List[Tuple[str, dict]]]
    ) -> np.ndarray:
        """
        Unpack 'segment_and_augments' to (seg, augments) and apply all the augmenters whose information is in
        'augments', in a reverse order, on the segmentation 'seg'. Handy as a single-argument callable for 'map'.

        :param segment_and_augments: A tuple in the following order: (seg, augments) where:
            • seg is a ndarray representing a segmentation.
            • augments is a list of tuples with augmenters information (see the returned value of the current class's
              method 'get_determenistic_augs_info'). Read carefully the documentation of the current class's method
              'create_back_augmenter_by_info'. The argument 'augments' is passed to this method.

        :return: The augmented-back segmentation
        """
        seg, augments = segment_and_augments
        return Augmenters.augment_back_segment(seg, augments)


def getLargestCC(segmentation, connectivity: int = 1):
    """
    Keep only the largest connected component of a segmentation.

    :param segmentation: An ndarray whose nonzero elements are foreground.
    :param connectivity: Passed to skimage.measure.label (1, the default, connects only face-adjacent elements).

    :raises ValueError: If 'segmentation' has no foreground element (an explicit check, unlike an 'assert', is not
        stripped when Python runs with -O).

    :return: An array with the shape and dtype of 'segmentation': 1 inside the largest component, 0 elsewhere.
    """
    labels = label(segmentation, connectivity=connectivity)
    if labels.size == 0 or labels.max() == 0:
        raise ValueError("getLargestCC requires at least one foreground element.")
    # bincount's index 0 counts the background label, so it is skipped and the argmax is shifted back by one
    largest_label = np.argmax(np.bincount(labels.ravel())[1:]) + 1
    return (labels == largest_label).astype(segmentation.dtype)


def calculate_runtime(t):
    """
    Format the wall-clock time elapsed since 't' as "hh:mm:ss".

    :param t: A start timestamp, as returned by time.time().

    :return: The elapsed time, truncated to whole seconds. Hours are not wrapped at 24 (e.g. "25:00:00"), and a
        start time in the future yields "00:00:00".
    """
    elapsed = max(0, int(time() - t))
    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"


class LiverPredictor:
    """
    A class for predicting segmentations of the liver in a given list of CT images through a given fitted model.

    Several augmentations are applied to each image; every augmented image is predicted, the predictions are
    augmented back and then combined. When the predictions are summed, a voxel is true in the final segmentation iff
    it was true in at least "some-given-threshold" (see 'predict') of the augmented-back predictions.
    """

    def __init__(
        self,
        model_file: str,
        augs: Augmenters = None,
        apply_augs: bool = True,
        n: int = None,
        seed: int = 42,
        model_dim: Tuple[int, int, int] = (128, 128, 48),
        pred_up_resize_order: int = 1,
    ):
        """
        :param model_file: A path to a fitted (keras) model - str. The model has to be such that it receives a 5 shape
            image to predict: (number of patches, number of channels=1, x, y ,z).
        :param augs: A collection of augmenters to apply on the images - Augmenters object (optional - if it's set to
            None a default container of some augmentations will be applied on the input images). If 'apply_augs' is set
            to False, the argument 'augs' is useless. Note, all the augmentations and sub-augmentations in the container
            augs must be with deterministic parameters only.
        :param apply_augs: A boolean. If set to true (by default) the augmentations that are in 'augs' (or in the
            default container in case 'augs' is set to None) will be applied on the images, otherwise augmentations
            won't be applied at all on the input images.
        :param n: int or None. The number of augmentations (chosen randomly) to apply on the images. If either set to
            None (by default) or is bigger than the number of augmentations available, all the augmentations will be
            applied on the input images. If 'apply_augs' is set to False, the argument 'n' is ignored.
        :param seed: The seed for the randomness of choosing 'n' augmentations to apply and for augmentations parameters
            (if seed is set to None, the randomness will be truly "random"). If either 'apply_augs' is set to False or
            'n' is set to None, the argument 'seed' is ignored.
        :param model_dim: A tuple containing the shape of each image that will be sent to the model.
        :param pred_up_resize_order: The interpolation order ('order' of skimage's resize) used when resizing the
            model's predictions back to the original image shape.

        :raises Exception: If 'augs' is given and is not an Augmenters object.
        """

        self.seed = seed
        self.apply_augs = apply_augs

        # loading the model
        self.model = load_model(model_file, compile=False)

        if augs is not None and not isinstance(augs, Augmenters):
            raise Exception("The augs argument most be an 'Augmenters' object.")

        if apply_augs:
            # 'original_also' is forwarded to apply_all_on_image_multiprocessed; it makes sure the original image is
            # also predicted (not only its augmentations)
            self.original_also: bool = True
            if augs is None:
                augs = self.get_predict_time_default_augmenters_container()
            self.augs: Augmenters = augs.get_n_random_augmenters(n, seed)
        else:
            # without augmentations the image is predicted once, through a single Identity augmenter
            self.original_also = False
            self.augs = Augmenters([Identity()])

        self.model_dim = model_dim
        self.pred_up_resize_order = pred_up_resize_order

    @staticmethod
    def get_predict_time_default_augmenters_container() -> Augmenters:
        """
        Creating a default container of augmentations (an Augmenters object) with a bunch of different augmenters, for
        predict-time augmentations. Each augmenter is a Sequential of a rotation, a scaling and a shearing, over every
        combination of the values listed below (5 rotations x 25 scales x 25 shears).

        :return: The container (Augmenters object).
        """

        # preparing scale intervals: every (("x", sx), ("y", sy)) pair
        optional_scales = (0.85, 0.9, 1, 1.1, 1.15)
        xs = product("x", optional_scales)
        ys = product("y", optional_scales)
        scales = list(product(xs, ys))

        # preparing shear intervals: every (("x", hx), ("y", hy)) pair
        optional_shear = (-7.5, -5, 0, 5, 7.5)
        xs = product("x", optional_shear)
        ys = product("y", optional_shear)
        shears = list(product(xs, ys))

        # NOTE: the order of this list matters - get_n_random_augmenters picks from it with a seed
        return Augmenters(
            [
                Sequential(
                    [
                        Affine(rotate=rotate),
                        Affine(scale=dict(scale)),
                        Affine(shear=dict(shear)),
                    ]
                )
                for rotate in (0, 15, -15, 7.5, -7.5)
                for scale in scales
                for shear in shears
            ]
        )

    @staticmethod
    def _min_max_normalize(image: np.ndarray) -> np.ndarray:
        """
        Linearly rescale 'image' to [0, 1].

        :param image: The image to rescale.

        :return: The rescaled image. A constant image becomes all zeros instead of all NaNs (division by zero).
        """
        low, high = image.min(), image.max()
        if high == low:
            return np.zeros_like(image)
        return (image - low) / (high - low)

    def predict(
        self,
        ct_niftis: List[str],
        liver_gt_niftis: List[str] = None,
        clip_intensities: Tuple[float, float] = (-150, 150),
        num_of_prediction_threshold: int = 3,
        summing_results: bool = True,
        get_ct_in_unet: bool = False,
        clip_original_ct: bool = False,
    ):
        """
        Predicting liver segmentations of the given list of CT images through self.model, by applying the augmentations
        in self.augs on the images, predicting them and computing the final predictions according to the given threshold.

        :param ct_niftis: A list of nifti files' names of CT scans.
        :param liver_gt_niftis: A list of nifti files' names of Liver-GT segmentations (optional, None by default).
        :param clip_intensities: A tuple (of shape 2) indicating the interesting intensities for predicting, by default (-150,150).
        :param num_of_prediction_threshold: The minimum number of segmentations (from the augmented ones) a voxel has to
            be true in, in order to be true in the final segmentation. If 'summing_results' is set to False (or no
            augmentations are applied), the argument 'num_of_prediction_threshold' is ignored.
        :param summing_results: If set to True (by default), the result predictions will be summed (superposition),
            otherwise the function will make a logical 'or' between all partial results.
        :param get_ct_in_unet: If set to True the second element in the yielded tuple will be the working image after
            resizing it to the model's shape and resizing it back to the original shape. If set to False (by default)
            it will be None.
        :param clip_original_ct: If set to False (by default) the first element in the yielded tuple will be the
            original image, otherwise it will be clipped to 'clip_intensities'.

        :raises Exception: If a CT path equals its GT path, if a CT looks like a mask (at most 2 unique values), or if
            a GT file contains values other than 0/1 after clipping to [0, 1].

        .. Note::
            The 'ct_niftis' list and the 'liver_gt_niftis' list must be of the same length and ordered respectively.
            So should each file in 'ct_niftis' be in the same shape as the respective file in 'liver_gt_niftis'.

        :return: Yields for each given CT case in 'ct_niftis', a tuple in the following form:
            (CT_Canonical, CT_in_Unet, GT_Liver, pred_seg, pred_label, nifti_file) where:
            • CT_Canonical is the original image (clipped to 'clip_intensities' if 'clip_original_ct' is True).
            • CT_in_Unet is the working image after resizing it to the model's shape and resizing it back to the
                original shape (if 'get_ct_in_unet' is False, it will be None).
            • GT_Liver is the given GT of the liver (if 'liver_gt_niftis' is None, it will be also None).
            • pred_seg is the final predicted segmentation (after thresholding if 'summing_results' is set to True),
                reduced to its largest connected component and with 2D holes filled.
            • pred_label is the combined prediction before thresholding: per-voxel counts when summing, otherwise the
                logical 'or' of the predictions.
            • nifti_file is the nifti file (a nibabel.Nifti1Image object) of the original CT case.
        """

        for i, ct_nifti_file_name in enumerate(ct_niftis):
            # start time, used for the per-case runtime message below
            t = time()

            if liver_gt_niftis is not None and ct_nifti_file_name == liver_gt_niftis[i]:
                raise Exception(
                    f"The case path and the liver case path are identical: {ct_nifti_file_name}"
                )

            # getting all the desired cases
            ct_case, ct_nifti_file = self.load_nifti_data(ct_nifti_file_name)

            if np.unique(ct_case).size <= 2:
                raise Exception(
                    f"It looks like the case file is in fact a mask file. It has less or equal to 2 unique variables: {ct_nifti_file_name}"
                )

            # the model always works on the clipped, [0, 1]-normalized image
            clipped_case = np.clip(ct_case, clip_intensities[0], clip_intensities[1])
            if clip_original_ct:
                ct_case = clipped_case
            working_case = self._min_max_normalize(clipped_case)

            # predicting every augmentation
            predict_label_case = self.__get_prediction(working_case, summing_results)

            # final segmentation prediction
            if summing_results and self.apply_augs:
                predict_seg_case = (
                    predict_label_case >= num_of_prediction_threshold
                ).astype(predict_label_case.dtype)
            else:
                predict_seg_case = predict_label_case.copy()

            # extract the largest connected component from the predicted liver segmentation
            predict_seg_case = getLargestCC(predict_seg_case).astype(
                predict_seg_case.dtype
            )

            # fill holes within each 2D slice along the last axis (the structure has extent 1 along that axis)
            predict_seg_case = binary_fill_holes(
                predict_seg_case,
                np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
                .reshape([3, 3, 1])
                .astype(predict_seg_case.dtype),
            )

            gt_liver_case = None
            if liver_gt_niftis is not None:
                gt_liver_case, _ = self.load_nifti_data(liver_gt_niftis[i])
                gt_liver_case = np.clip(gt_liver_case, 0, 1)
                if np.any((gt_liver_case != 0) & (gt_liver_case != 1)):
                    raise Exception(
                        f"It looks like the liver case file is in fact not a mask case. It contains variables different than 0 or 1: {liver_gt_niftis[i]}"
                    )

            ct_in_unet_case = None
            if get_ct_in_unet:
                ct_in_unet_case = resize(working_case, self.model_dim)
                ct_in_unet_case = resize(ct_in_unet_case, ct_case.shape)

            print(
                f'Finished case "{basename(dirname(ct_nifti_file_name))}" in {calculate_runtime(t)} (hh/mm/ss)'
            )

            yield (
                ct_case,
                ct_in_unet_case,
                gt_liver_case,
                predict_seg_case,
                predict_label_case,
                ct_nifti_file,
            )

    def __get_prediction(
        self, working_case: np.ndarray, summing_results: bool
    ) -> np.ndarray:
        """
        Getting the prediction according to self.model.

        Every augmentation in self.augs is applied to 'working_case'. Each augmented image is resized to
        self.model_dim and predicted, and the prediction is binarized at 0.5 and reduced to its largest connected
        component. It is then resized back to the original shape, re-binarized and augmented back, and finally all
        predictions are combined.

        :param working_case: The (normalized) CT case (ndarray).
        :param summing_results: If set to True, the result predictions will be summed (superposition), otherwise
            the function will make a logical 'or' between all partial results.

        :return: The prediction result (ndarray).
        """

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            auged_cases, aug_infos = zip(
                *self.augs.apply_all_on_image_multiprocessed(
                    working_case, original_also=self.original_also
                )
            )

        # resizing the augmented images to the model's shape and adding the channel axis:
        # (n_augs, *model_dim) -> (n_augs, 1, *model_dim)
        model_inputs = np.stack(
            [resize(case, output_shape=self.model_dim) for case in auged_cases]
        )
        model_inputs = np.expand_dims(model_inputs, axis=1)

        with tf.device("GPU:0"):
            # one image per model.predict call, collecting garbage in between
            # (presumably to bound GPU/host memory - TODO confirm before batching more)
            batch_preds = []
            for k in range(model_inputs.shape[0]):
                batch_preds.append(self.model.predict(model_inputs[k : k + 1]))
                gc.collect()
            preds = np.concatenate(batch_preds)
            preds = (preds >= 0.5).astype(preds.dtype)

        # extract the largest connected component of channel 0 of every prediction. An empty prediction (possible
        # for a strong augmentation) has no component and is kept as-is instead of aborting the step for all of them.
        preds = [getLargestCC(pred[0]) if pred[0].any() else pred[0] for pred in preds]

        # resizing back the predictions to the original shape
        preds = np.stack(
            [
                resize(
                    pred,
                    output_shape=working_case.shape,
                    anti_aliasing=False,
                    order=self.pred_up_resize_order,
                )
                for pred in preds
            ]
        )
        preds = (preds > 0.5).astype(preds.dtype)

        # augmenting back the augmented predictions
        preds = self.augs.augment_back_arrays_of_segmentations_by_informations(
            preds, aug_infos
        )

        if summing_results:
            return preds.sum(axis=0)
        return np.any(preds, axis=0)

    @staticmethod
    def load_nifti_data(nifti_file_name: str):
        """
        Loading data from a nifti file, reoriented to the closest canonical orientation.

        :param nifti_file_name: The path to the desired nifti file.

        :return: A tuple in the following form: (data, file), where:
        • data is a float32 ndarray containing the loaded data.
        • file is the (canonical) file object that was loaded.
        """

        # loading nifti file
        nifti_file = load(nifti_file_name)
        nifti_file = as_closest_canonical(nifti_file)

        # extracting the data of the file
        data = nifti_file.get_fdata().astype(np.float32)

        return data, nifti_file


def _model_fullpath(model_filename):
    """Return the path of 'model_filename' inside the folder named by the MODELS_FOLDER environment variable."""
    models_folder = _MODELS_FOLDER
    return os.path.join(models_folder, model_filename)


def liver_segmentation(**kwargs):
    """
    Segment the liver in a single CT scan and write the predicted mask to disk.

    Keyword arguments:
        input_image: Path to the CT nifti file (required).
        output_image: Path of the output mask (required). When 'save_nii' is False the mask is saved with np.save as a
            boolean array (np.save appends ".npy" if the path lacks it). When it is True the mask is saved there as a
            float32 nifti. In both cases a float32 nifti copy is also written to 'output_image' with ".npy" replaced
            by ".nii.gz".
        save_nii: See 'output_image' (default False).

    :raises ValueError: If 'input_image' or 'output_image' is missing.

    :return: The list of file paths that were written.
    """
    ct_file = kwargs.get("input_image")
    output_segmentation_file = kwargs.get("output_image")
    save_nifti = kwargs.get("save_nii", False)
    if ct_file is None or output_segmentation_file is None:
        raise ValueError("Both 'input_image' and 'output_image' must be given.")

    liver_predictor = LiverPredictor(
        _model_fullpath(settings.MODELS.LIVER_SEGMENTATION), n=30, apply_augs=True
    )
    pred_results = liver_predictor.predict(
        [ct_file],
        None,
        get_ct_in_unet=False,
        clip_original_ct=True,
        num_of_prediction_threshold=settings.LIVER_TRESHOLD,
    )

    nifti_copy_file = output_segmentation_file.replace(".npy", ".nii.gz")
    result_files = []
    for _, _, _, predict_case, _, ct_nifti_file in pred_results:
        mask_image = nib.Nifti1Image(
            predict_case.astype(np.float32), ct_nifti_file.affine
        )
        if save_nifti:
            nib.save(mask_image, output_segmentation_file)
            result_files.append(output_segmentation_file)
        else:
            np.save(output_segmentation_file, predict_case.astype(bool))
            # np.save appends ".npy" when the path does not already end with it
            if output_segmentation_file.endswith(".npy"):
                result_files.append(output_segmentation_file)
            else:
                result_files.append(output_segmentation_file + ".npy")

        # the nifti copy; skipped only when it would rewrite the very same nifti file
        if not (save_nifti and nifti_copy_file == output_segmentation_file):
            nib.save(mask_image, nifti_copy_file)
            result_files.append(nifti_copy_file)

    return result_files


def _get_attachment_filename(response):
    return re.findall("filename=(.+)", response.headers["content-disposition"])[0]


def task(**kwargs):
    """
    Run the liver-segmentation step for a task on the tasks server.

    Downloads the cropped CT scan of the task, segments the liver in it, and uploads the result to the task's "liver"
    node twice: as "liver.npy" and as "liver.nii.gz". All local files live in a temporary directory that is removed
    even if a step fails.

    :param task_id: (keyword) The id of the task on the tasks server.

    :raises requests.HTTPError: If the download or one of the uploads returns an error status.
    """
    task_id = kwargs.get("task_id")

    # TASKS_SERVER_URL already ends with "/"; stripping it avoids "//api" in the URLs below
    base_url = TASKS_SERVER_URL.rstrip("/")
    image_file_url = f"{base_url}/api/task/{task_id}/node/cropping/file/cropped.nii.gz"
    new_files_url = f"{base_url}/api/task/{task_id}/node/liver/file"

    image_file_response = requests.get(image_file_url)
    image_file_response.raise_for_status()
    image_filename = _get_attachment_filename(image_file_response)

    with tempfile.TemporaryDirectory() as tempdir:
        image_local_filepath = os.path.join(tempdir, image_filename)
        with open(image_local_filepath, "wb") as image_file:
            image_file.write(image_file_response.content)

        # "<name>.nii.gz" -> "liver-<name>.npy". The name always ends with ".npy" so np.save does not append another
        # suffix, and the ".nii.gz" copy written next to it lands on a different path.
        stem = image_filename
        for suffix in (".nii.gz", ".nii"):
            if stem.endswith(suffix):
                stem = stem[: -len(suffix)]
                break
        output_image_filename = os.path.join(tempdir, f"liver-{stem}.npy")

        # TODO
        # In cases the image that's in 'image_local_filepath' has some weird shape for some reason
        # (x, y, 1) or similar, do not continue to processing!!
        liver_segmentation(
            input_image=image_local_filepath, output_image=output_image_filename
        )

        uploads = (
            ("liver.npy", output_image_filename),
            ("liver.nii.gz", output_image_filename.replace(".npy", ".nii.gz")),
        )
        for upload_name, local_path in uploads:
            with open(local_path, "rb") as upload_file:
                response = requests.put(
                    new_files_url,
                    files={
                        "filename": (None, upload_name),
                        "file": ("file", upload_file),
                    },
                )
            response.raise_for_status()