Image dataset loader and processing utilities

mail@pastecode.io avatar
unknown
python
2 years ago
12 kB
3
Indexable
Never
import torch
from load_images import ImagesLoader
import concurrent.futures
import math
import os
from PIL import Image
import numpy
from concurrent.futures import ThreadPoolExecutor
import multiprocessing


class ImagesLoader:
    """Loads every image in a folder into two paired uint8 arrays:

    - ``labels``: (N, 3, out_ress, out_ress) — center-cropped high-resolution targets
    - ``images``: (N, 3, in_ress, in_ress)   — the same crops downscaled to the input size

    Loading is fanned out across ``cpu_count()`` threads (I/O-bound: decoding
    files with PIL releases the GIL in the hot paths).
    """

    def __init__(self, path, in_ress, out_ress):
        self.path       = path
        self.in_ress    = in_ress
        self.out_ress   = out_ress
        self.channels   = 3

        self.files      = os.listdir(path)
        self.imgs_count = len(self.files)
        # Workers write disjoint index ranges into these preallocated arrays,
        # so no locking is needed.
        self.images     = numpy.zeros((self.imgs_count, self.channels, self.in_ress, self.in_ress), dtype=numpy.uint8)
        self.labels     = numpy.zeros((self.imgs_count, self.channels, self.out_ress, self.out_ress), dtype=numpy.uint8)

    def load_images(self):
        """Load all images in parallel and return ``(images, labels)``.

        Raises:
            Exception: if there are fewer images than CPU cores, or if any
                worker fails (the worker's exception is re-raised here).
        """
        processes_count = multiprocessing.cpu_count()
        if self.imgs_count < processes_count:
            raise Exception("The number of images in the dataset is less than the number of CPU cores,"
                            " please increase the count of training images to at least {}".format(processes_count))

        # Split [0, imgs_count) into contiguous chunks, one per worker; the
        # last chunk absorbs the remainder.
        # Fix: the original loop computed start = previous_chunk_stop even for
        # z == 0 when processes_count == 1, which skipped the first chunk of
        # images entirely on single-core machines.
        chunk_size = self.imgs_count // processes_count
        stop_start_ids = []
        for z in range(processes_count):
            start = z * chunk_size
            stop  = self.imgs_count if z == processes_count - 1 else start + chunk_size
            stop_start_ids.append([start, stop])

        with ThreadPoolExecutor(max_workers=processes_count) as executor:
            futures = [executor.submit(self.run_process, ids) for ids in stop_start_ids]
            # Fix: the original never inspected the futures, so worker
            # exceptions were silently swallowed and zero-filled arrays
            # returned. .result() re-raises any worker exception.
            for f in futures:
                f.result()

        return self.images, self.labels

    def run_process(self, start_stop):
        """Load images[start_stop[0]:start_stop[1]]: resize so the short side
        equals out_ress, center-crop to a square, and store the CHW arrays."""
        for i in range(start_stop[0], start_stop[1]):
            y = Image.open(os.path.join(self.path, self.files[i])).convert("RGB")

            if y.size[0] > y.size[1]:
                # Landscape: scale height to out_ress, center-crop the width.
                new_w = math.floor((y.size[0] / y.size[1]) * self.out_ress)
                y = y.resize((new_w, self.out_ress), Image.BICUBIC)
                left = (y.size[0] - self.out_ress) // 2
                y = y.crop((left, 0, left + self.out_ress, self.out_ress))

            elif y.size[0] == y.size[1]:
                y = y.resize((self.out_ress, self.out_ress), Image.BICUBIC)

            else:
                # Portrait: scale width to out_ress, center-crop the height.
                # Fix: the original computed the crop from y.size[0] (already
                # out_ress after the resize) and cropped horizontally, so it
                # always kept the TOP square instead of the vertical center.
                new_h = math.floor((y.size[1] / y.size[0]) * self.out_ress)
                y = y.resize((self.out_ress, new_h), Image.BICUBIC)
                top = (y.size[1] - self.out_ress) // 2
                y = y.crop((0, top, self.out_ress, top + self.out_ress))

            # HWC -> CHW for both the high-res label and the downscaled input.
            self.labels[i] = numpy.moveaxis(numpy.array(y), -1, 0)

            x = y.resize((self.in_ress, self.in_ress), Image.BICUBIC)
            self.images[i] = numpy.moveaxis(numpy.array(x), -1, 0)

class Process_dataset:
    """Loads training (and optionally testing/validation) image folders via
    ImagesLoader and serves randomly sampled, optionally augmented torch
    batches.

    The ``*_images`` / ``*_labels`` attributes are lists of "groups" of numpy
    arrays; batching picks a random group, then a random image inside it.
    With augmentation enabled the training groups are
    [raw, raw, rotated-augments] — the raw data appears twice so it stays as
    likely to be sampled as the augmented copies.
    """

    def __init__(self, in_ress, out_ress, training_path, aug_count, testing_path= False, validate_path= False):
        self.in_ress           = in_ress
        self.out_ress          = out_ress
        self.training_path     = training_path
        self.testing_path      = testing_path
        self.validate_path     = validate_path
        self.aug_count         = aug_count

        self.training_images   = []
        self.training_labels   = []
        self.validate_images   = []
        self.validate_labels   = []
        self.testing_images    = []
        self.testing_labels    = []

        if self.aug_count < 0:
            raise Exception("The number of augments aug_count can not be negative")

        train_loader = ImagesLoader(self.training_path, self.in_ress, self.out_ress)
        training_images_raw, training_labels_raw = train_loader.load_images()

        if self.testing_path:
            test_loader = ImagesLoader(self.testing_path, self.in_ress, self.out_ress)
            testing_images, testing_labels = test_loader.load_images()

            self.testing_images.append(testing_images)
            self.testing_labels.append(testing_labels)

            self.testing_count = len(testing_images)
            print("Testing images count - {}".format(self.testing_count))

        if self.validate_path:
            validate_loader = ImagesLoader(self.validate_path, self.in_ress, self.out_ress)
            validate_images, validate_labels = validate_loader.load_images()

            self.validate_images.append(validate_images)
            self.validate_labels.append(validate_labels)

            self.validate_count = len(validate_images)
            print("Validate images count - {}".format(self.validate_count))

        self.training_images.append(training_images_raw)
        self.training_labels.append(training_labels_raw)

        if self.aug_count != 0:
            # Second raw group keeps unaugmented images well-represented when
            # sampling (see process_batch: group 1 gets flip-only treatment).
            self.training_images.append(training_images_raw)
            self.training_labels.append(training_labels_raw)

            images_aug, labels_aug = self._augmentation(self.training_images, self.training_labels, self.aug_count)

            self.training_images.append(images_aug)
            self.training_labels.append(labels_aug)

            self.training_count = len(training_images_raw) * 2 + len(images_aug)
        else:
            self.training_count = len(training_images_raw)

        # Fix: the original printed the training count twice when aug_count != 0.
        print("Training images count - {}".format(self.training_count))
        print("\n\n")

    def get_training_batch(self, batch_size):
        return self.get_batch(self.training_images, self.training_labels, batch_size)

    def get_training_count(self):
        return self.training_count

    def get_testing_batch(self, batch_size):
        if not self.testing_path:
            raise Exception("No testing data... Please specify testing path folder in Process_dataset class")
        return self.get_batch(self.testing_images, self.testing_labels, batch_size, training=False)

    def get_testing_count(self):
        if not self.testing_path:
            raise Exception("No testing data... Please specify testing path folder in Process_dataset class")
        return self.testing_count

    def get_validate_batch(self, batch_size):
        if not self.validate_path:
            raise Exception("No validate data... Please specify validate path folder in Process_dataset class")
        return self.get_batch(self.validate_images, self.validate_labels, batch_size, training=False)

    def get_validate_count(self):
        if not self.validate_path:
            raise Exception("No validate data... Please specify validate path folder in Process_dataset class")
        return self.validate_count

    def process_augumentation(self, image, label):
        """Rotate one (C, H, W) uint8 image/label pair by the same random
        angle in [-25, 25] degrees and return the rotated CHW pair."""
        image_aug = numpy.array(image, dtype=numpy.uint8)
        label_aug = numpy.array(label, dtype=numpy.uint8)

        angle_max = 25
        angle     = self._rnd(-angle_max, angle_max)

        image_in  = Image.fromarray(numpy.moveaxis(image_aug, 0, 2), 'RGB')
        mask_in   = Image.fromarray(numpy.moveaxis(label_aug, 0, 2), 'RGB')

        image_aug = image_in.rotate(angle, resample=Image.BICUBIC)
        mask_aug  = mask_in.rotate(angle, resample=Image.BICUBIC)

        # Fix: convert back with moveaxis(-1, 0) (HWC -> CHW). The original
        # used swapaxes(0, 2), which also transposes height and width and so
        # silently flipped every augmented image across its diagonal (it only
        # went unnoticed because the images are square).
        image_aug = numpy.moveaxis(numpy.array(image_aug), -1, 0)
        mask_aug  = numpy.moveaxis(numpy.array(mask_aug), -1, 0)

        return image_aug, mask_aug

    def _augmentation(self, images, labels, aug_count):
        """Build ``aug_count`` rotated variants of every raw training image,
        cycling through the first (raw) group. Returns (images_aug, labels_aug)."""
        count      = len(images[0])
        total      = count * aug_count
        images_aug = numpy.zeros((total, 3, self.in_ress, self.in_ress), dtype=numpy.uint8)
        labels_aug = numpy.zeros((total, 3, self.out_ress, self.out_ress), dtype=numpy.uint8)

        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            futures = [executor.submit(self.process_augumentation,
                                       images[0][i % count], labels[0][i % count])
                       for i in range(total)]

            # Completion order is arbitrary, but each image/label pair stays
            # together, which is all the sampler relies on.
            for slot, f in enumerate(concurrent.futures.as_completed(futures)):
                images_aug[slot], labels_aug[slot] = f.result()

        return images_aug, labels_aug

    def get_batch(self, images, labels, batch_size, training = True):
        """Assemble a (batch_size, 3, ress, ress) float batch by sampling
        ``batch_size`` random images concurrently."""
        result_x = torch.zeros((batch_size, 3, self.in_ress, self.in_ress)).float()
        result_y = torch.zeros((batch_size, 3, self.out_ress, self.out_ress)).float()

        with ThreadPoolExecutor(max_workers=batch_size) as executor:
            futures = [executor.submit(self.process_batch, images, labels, training)
                       for _ in range(batch_size)]

            for slot, f in enumerate(concurrent.futures.as_completed(futures)):
                result_x[slot], result_y[slot] = f.result()

        return result_x, result_y

    def process_batch(self, images, labels, training):
        """Sample one random image/label pair, scale to [0, 1], optionally
        augment, and return it as a pair of float tensors."""
        group_idx = numpy.random.randint(len(images))
        image_idx = numpy.random.randint(len(images[group_idx]))

        image_np  = numpy.array(images[group_idx][image_idx]) / 255.0
        label_np  = numpy.array(labels[group_idx][image_idx]) / 255.0

        # Evaluation batches (and training without augmentation) pass through
        # unmodified. The original had two identical early-return branches.
        if not training or self.aug_count == 0:
            return torch.from_numpy(image_np).float(), torch.from_numpy(label_np).float()

        if group_idx == 1:
            # Group 1 is the "clean" raw copy: flip only, no noise.
            image_np, mask_np = self._augmentation_flip(image_np, label_np)
        else:
            image_np, mask_np = self._augmentation_flip(image_np, label_np)
            image_np, mask_np = self._augmentation_noise(image_np, mask_np)

        return torch.from_numpy(image_np).float(), torch.from_numpy(mask_np).float()

    def _augmentation_flip(self, image_np, label_np):
        """Flip image and label together, vertically or horizontally at random."""
        axis = 1 if self._rnd(0, 1) < 0.5 else 2
        aug_img   = numpy.flip(image_np, axis)
        aug_label = numpy.flip(label_np, axis)

        # .copy() drops the negative-stride view so torch.from_numpy accepts it.
        return aug_img.copy(), aug_label.copy()

    def _augmentation_noise(self, image_np, label_np):
        """Apply the same random brightness/contrast jitter to image and label,
        clipping both back into [0, 1]."""
        brightness = self._rnd(-0.30, 0.25)
        contrast   = self._rnd(0.6, 1.1)

        img_result = image_np + brightness
        img_result = 0.5 + contrast * (img_result - 0.5)

        label_result = label_np + brightness
        label_result = 0.5 + contrast * (label_result - 0.5)

        return numpy.clip(img_result, 0.0, 1.0), numpy.clip(label_result, 0.0, 1.0)

    def _rnd(self, min_value, max_value):
        """Uniform random float in [min_value, max_value)."""
        return (max_value - min_value) * numpy.random.rand() + min_value