Loader and process dataset
unknown
python
4 years ago
12 kB
7
Indexable
import torch
from load_images import ImagesLoader
import concurrent.futures
import math
import os
from PIL import Image
import numpy
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
class ImagesLoader:
    """Load every file of a directory as an (input, label) image pair.

    Each file is opened with PIL and converted to RGB. The label is the
    image scaled so that its shorter side equals ``out_ress`` and then
    centre-cropped to an ``out_ress`` x ``out_ress`` square; the input is
    that square resized to ``in_ress`` x ``in_ress``. Both are stored
    channel-first as ``uint8``.
    """

    def __init__(self, path, in_ress, out_ress):
        """Index ``path`` and pre-allocate the result arrays.

        Args:
            path: Directory to read; every entry returned by
                ``os.listdir`` is treated as an image file.
            in_ress: Side length in pixels of the square input images.
            out_ress: Side length in pixels of the square label images.
        """
        self.path = path
        self.in_ress = in_ress
        self.out_ress = out_ress
        self.channels = 3
        self.files = os.listdir(path)
        self.imgs_count = len(self.files)
        self.images = numpy.zeros(
            (self.imgs_count, self.channels, self.in_ress, self.in_ress),
            dtype=numpy.uint8,
        )
        self.labels = numpy.zeros(
            (self.imgs_count, self.channels, self.out_ress, self.out_ress),
            dtype=numpy.uint8,
        )

    def load_images(self):
        """Read all indexed files into ``images`` / ``labels`` with a thread pool.

        The file list is split into contiguous index ranges, one per worker
        thread; the last range absorbs the remainder. The worker count is
        the CPU count capped at the number of files, so datasets smaller
        than the core count (and single-core machines) load correctly.

        Returns:
            Tuple ``(images, labels)`` - the arrays allocated in
            ``__init__``, now filled.

        Raises:
            Exception: If the directory contains no files, or whatever
                exception a worker hit while reading one of them.
        """
        if self.imgs_count == 0:
            raise Exception(
                "The dataset folder {} does not contain any images".format(self.path)
            )

        workers = min(multiprocessing.cpu_count(), self.imgs_count)
        per_worker = self.imgs_count // workers
        spans = [[w * per_worker, (w + 1) * per_worker] for w in range(workers)]
        spans[-1][1] = self.imgs_count  # last worker also takes the remainder

        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [executor.submit(self.run_process, span) for span in spans]

        # Re-raise worker failures; otherwise an unreadable file would abort
        # its chunk unnoticed and leave the remaining slots zero-filled.
        for future in futures:
            future.result()
        return self.images, self.labels

    def run_process(self, start_stop):
        """Load the files whose indices lie in ``[start_stop[0], start_stop[1])``.

        Args:
            start_stop: Two-element sequence ``(start, stop)`` of indices
                into ``self.files``.
        """
        for i in range(start_stop[0], start_stop[1]):
            with Image.open(os.path.join(self.path, self.files[i])) as source:
                # NOTE: EXIF orientation is not applied to the decoded image.
                y = source.convert("RGB")
            y = self._resize_and_center_crop(y)
            # PIL yields (H, W, C); the arrays are stored as (C, H, W).
            self.labels[i] = numpy.moveaxis(numpy.array(y), -1, 0)
            x = y.resize((self.in_ress, self.in_ress), Image.BICUBIC)
            self.images[i] = numpy.moveaxis(numpy.array(x), -1, 0)

    def _resize_and_center_crop(self, img):
        """Scale ``img`` so its shorter side is ``out_ress`` and crop the central square."""
        side = self.out_ress
        width, height = img.size
        if width == height:
            return img.resize((side, side), Image.BICUBIC)
        if width > height:
            scaled_width = math.floor((width / height) * side)
            img = img.resize((scaled_width, side), Image.BICUBIC)
            left = math.floor((scaled_width - side) / 2)
            return img.crop((left, 0, left + side, side))
        # Portrait: the surplus is vertical, so the offset must come from the
        # height (deriving it from the width always gave the top square).
        scaled_height = math.floor((height / width) * side)
        img = img.resize((side, scaled_height), Image.BICUBIC)
        top = math.floor((scaled_height - side) / 2)
        return img.crop((0, top, side, top + side))
class Process_dataset:
    """Load training / testing / validation image sets and serve random batches.

    Every set is stored as a list of "groups"; each group is an array of
    channel-first ``uint8`` images. When ``aug_count`` is non-zero the
    training set holds three groups: the raw images, the raw images again,
    and ``aug_count`` randomly rotated copies of every raw image.
    ``process_batch`` applies a different on-the-fly augmentation to
    group 1 than to the other groups.
    """

    def __init__(self, in_ress, out_ress, training_path, aug_count, testing_path=False, validate_path=False):
        """Load all configured folders and pre-compute rotated training copies.

        Args:
            in_ress: Side length in pixels of the square input images.
            out_ress: Side length in pixels of the square label images.
            training_path: Folder with the training images.
            aug_count: Number of rotated copies to create per training
                image; ``0`` disables all training augmentation.
            testing_path: Folder with testing images, or a falsy value
                when there is no testing set.
            validate_path: Folder with validation images, or a falsy
                value when there is no validation set.

        Raises:
            Exception: If ``aug_count`` is negative.
        """
        self.in_ress = in_ress
        self.out_ress = out_ress
        self.training_path = training_path
        self.testing_path = testing_path
        self.validate_path = validate_path
        self.aug_count = aug_count
        self.training_images = []
        self.training_labels = []
        self.validate_images = []
        self.validate_labels = []
        self.testing_images = []
        self.testing_labels = []

        if self.aug_count < 0:
            raise Exception("The number of augments aug_count can not be negative")

        train_loader = ImagesLoader(self.training_path, self.in_ress, self.out_ress)
        training_images_raw, training_labels_raw = train_loader.load_images()

        if self.testing_path:
            test_loader = ImagesLoader(self.testing_path, self.in_ress, self.out_ress)
            testing_images, testing_labels = test_loader.load_images()
            self.testing_images.append(testing_images)
            self.testing_labels.append(testing_labels)
            self.testing_count = len(testing_images)
            print("Testing images count - {}".format(self.testing_count))

        if self.validate_path:
            validate_loader = ImagesLoader(self.validate_path, self.in_ress, self.out_ress)
            validate_images, validate_labels = validate_loader.load_images()
            self.validate_images.append(validate_images)
            self.validate_labels.append(validate_labels)
            self.validate_count = len(validate_images)
            print("Validate images count - {}".format(self.validate_count))

        self.training_images.append(training_images_raw)
        self.training_labels.append(training_labels_raw)
        if self.aug_count != 0:
            # The raw set is stored a second time as group 1, which
            # process_batch augments with a flip only.
            self.training_images.append(training_images_raw)
            self.training_labels.append(training_labels_raw)
            images_aug, labels_aug = self._auqumentation(
                self.training_images, self.training_labels, self.aug_count
            )
            self.training_images.append(images_aug)
            self.training_labels.append(labels_aug)
            self.training_count = len(training_images_raw) * 2 + len(images_aug)
        else:
            self.training_count = len(training_images_raw)
        print("Training images count - {}".format(self.training_count))
        print("\n\n")

    def get_training_batch(self, batch_size):
        """Return a random ``(inputs, labels)`` training batch of ``batch_size`` samples."""
        return self.get_batch(self.training_images, self.training_labels, batch_size)

    def get_training_count(self):
        """Return the number of stored training samples, all groups included."""
        return self.training_count

    def get_testing_batch(self, batch_size):
        """Return a random, un-augmented testing batch.

        Raises:
            Exception: If no testing folder was configured.
        """
        if not self.testing_path:
            # The message previously pointed at the training path by mistake.
            raise Exception("No testing data... Please specify testing path folder in Process_dataset class")
        return self.get_batch(self.testing_images, self.testing_labels, batch_size, training=False)

    def get_testing_count(self):
        """Return the number of testing images.

        Raises:
            Exception: If no testing folder was configured.
        """
        if not self.testing_path:
            raise Exception("No testing data... Please specify testing path folder in Process_dataset class")
        return self.testing_count

    def get_validate_batch(self, batch_size):
        """Return a random, un-augmented validation batch.

        Raises:
            Exception: If no validation folder was configured.
        """
        if not self.validate_path:
            raise Exception("No validate data... Please specify validate path folder in Process_dataset class")
        return self.get_batch(self.validate_images, self.validate_labels, batch_size, training=False)

    def get_validate_count(self):
        """Return the number of validation images.

        Raises:
            Exception: If no validation folder was configured.
        """
        if not self.validate_path:
            raise Exception("No validate data... Please specify validate path folder in Process_dataset class")
        return self.validate_count

    def process_augumentation(self, image, label):
        """Rotate one image/label pair by the same random angle.

        The angle is drawn uniformly from [-25, 25) degrees.

        Args:
            image: Channel-first ``uint8`` input image.
            label: Channel-first ``uint8`` label image.

        Returns:
            Tuple ``(image_aug, label_aug)`` of rotated channel-first arrays.
        """
        angle_max = 25
        angle = self._rnd(-angle_max, angle_max)

        def rotate(channel_first):
            pixels = numpy.array(channel_first, dtype=numpy.uint8)
            picture = Image.fromarray(numpy.moveaxis(pixels, 0, 2), 'RGB')
            rotated = numpy.array(picture.rotate(angle, resample=Image.BICUBIC))
            # Exact inverse of the (C, H, W) -> (H, W, C) move above.
            # swapaxes(0, 2) would return (C, W, H), i.e. a transposed image.
            return numpy.moveaxis(rotated, -1, 0)

        return rotate(image), rotate(label)

    def _auqumentation(self, images, labels, aug_count):
        """Build ``aug_count`` rotated copies of every image in group 0.

        Args:
            images: List of image groups; only ``images[0]`` is read.
            labels: List of label groups; only ``labels[0]`` is read.
            aug_count: Number of rotated copies per source image.

        Returns:
            Tuple ``(images_aug, labels_aug)`` of ``uint8`` arrays holding
            ``len(images[0]) * aug_count`` samples each. Samples are stored
            in completion order; image and label of a pair share an index.
        """
        count = len(images[0])
        total = count * aug_count
        images_aug = numpy.zeros((total, 3, self.in_ress, self.in_ress), dtype=numpy.uint8)
        labels_aug = numpy.zeros((total, 3, self.out_ress, self.out_ress), dtype=numpy.uint8)
        print(images_aug.shape)
        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            # k % count cycles through the source images aug_count times.
            futures = [
                executor.submit(
                    self.process_augumentation,
                    images[0][k % count],
                    labels[0][k % count],
                )
                for k in range(total)
            ]
            for slot, future in enumerate(concurrent.futures.as_completed(futures)):
                images_aug[slot], labels_aug[slot] = future.result()
        return images_aug, labels_aug

    def get_batch(self, images, labels, batch_size, training=True):
        """Assemble a batch of randomly drawn samples as float tensors.

        Each sample is drawn independently by ``process_batch``, so a
        batch may contain the same image more than once.

        Args:
            images: List of image groups.
            labels: List of label groups, parallel to ``images``.
            batch_size: Number of samples to draw.
            training: When ``False`` no augmentation is applied.

        Returns:
            Tuple ``(result_x, result_y)`` of float tensors shaped
            ``(batch_size, 3, in_ress, in_ress)`` and
            ``(batch_size, 3, out_ress, out_ress)``.
        """
        result_x = torch.zeros((batch_size, 3, self.in_ress, self.in_ress)).float()
        result_y = torch.zeros((batch_size, 3, self.out_ress, self.out_ress)).float()
        with ThreadPoolExecutor(max_workers=batch_size) as executor:
            futures = [
                executor.submit(self.process_batch, images, labels, training)
                for _ in range(batch_size)
            ]
            for slot, future in enumerate(concurrent.futures.as_completed(futures)):
                result_x[slot], result_y[slot] = future.result()
        return result_x, result_y

    def process_batch(self, images, labels, training):
        """Draw one random sample, scale it to [0, 1] and optionally augment it.

        A group is picked uniformly first, then an image inside it.
        Augmentation runs only when ``training`` is true and
        ``aug_count`` is non-zero: group 1 gets a random flip, every other
        group gets a random flip followed by a brightness/contrast shift.

        Returns:
            Tuple ``(result_x, result_y)`` of float tensors.
        """
        group_idx = numpy.random.randint(len(images))
        image_idx = numpy.random.randint(len(images[group_idx]))
        image_np = numpy.array(images[group_idx][image_idx]) / 255.0
        label_np = numpy.array(labels[group_idx][image_idx]) / 255.0

        if not training or self.aug_count == 0:
            return torch.from_numpy(image_np).float(), torch.from_numpy(label_np).float()

        image_np, mask_np = self._augmentation_flip(image_np, label_np)
        if group_idx != 1:
            image_np, mask_np = self._augmentation_noise(image_np, mask_np)
        return torch.from_numpy(image_np).float(), torch.from_numpy(mask_np).float()

    def _augmentation_flip(self, image_np, label_np):
        """Flip image and label along the same randomly chosen axis (1 or 2)."""
        axis = 1 if self._rnd(0, 1) < 0.5 else 2
        # numpy.flip returns a view; copy so the caller owns contiguous data.
        return numpy.flip(image_np, axis).copy(), numpy.flip(label_np, axis).copy()

    def _augmentation_noise(self, image_np, label_np):
        """Apply one random brightness and contrast change to image and label.

        Brightness is drawn from [-0.30, 0.25) and contrast from
        [0.6, 1.1); the results are clipped to [0, 1].
        """
        brightness = self._rnd(-0.30, 0.25)
        contrast = self._rnd(0.6, 1.1)

        def adjust(values):
            shifted = values + brightness
            return numpy.clip(0.5 + contrast * (shifted - 0.5), 0.0, 1.0)

        return adjust(image_np), adjust(label_np)

    def _rnd(self, min_value, max_value):
        """Return a uniform random float in ``[min_value, max_value)``."""
        return (max_value - min_value) * numpy.random.rand() + min_value