Loader and process dataset
unknown
python
3 years ago
12 kB
6
Indexable
import torch
# NOTE(review): the name imported here is rebound by the local
# ``class ImagesLoader`` defined below, so this module never uses the imported
# object. It is kept only in case importing ``load_images`` has side effects.
from load_images import ImagesLoader
import concurrent.futures
import math
import os
from PIL import Image
import numpy
from concurrent.futures import ThreadPoolExecutor
import multiprocessing


class ImagesLoader:
    """Load every file in a folder as an (input, label) pair of uint8 CHW arrays.

    Each file is decoded to RGB, scaled so its shorter side equals ``out_ress``
    and center-cropped to ``out_ress x out_ress``; that is the label. The input
    is the label resized to ``in_ress x in_ress``.
    """

    def __init__(self, path, in_ress, out_ress):
        self.path = path
        self.in_ress = in_ress
        self.out_ress = out_ress
        self.channels = 3
        # Sorted so that row i maps to the same file on every run; directories
        # are skipped because they cannot be opened as images.
        self.files = sorted(
            name for name in os.listdir(path)
            if os.path.isfile(os.path.join(path, name))
        )
        self.imgs_count = len(self.files)
        # Pre-allocated outputs; run_process fills them row by row.
        self.images = numpy.zeros(
            (self.imgs_count, self.channels, self.in_ress, self.in_ress),
            dtype=numpy.uint8)
        self.labels = numpy.zeros(
            (self.imgs_count, self.channels, self.out_ress, self.out_ress),
            dtype=numpy.uint8)

    def load_images(self):
        """Decode all files using a thread pool and return ``(images, labels)``.

        Returns:
            images: uint8 array of shape (N, 3, in_ress, in_ress).
            labels: uint8 array of shape (N, 3, out_ress, out_ress).

        Raises:
            Exception: if the folder contains no files.
            Any error raised while opening or decoding one of the files.
        """
        if self.imgs_count == 0:
            raise Exception("No images found in {}".format(self.path))

        # One contiguous chunk per worker; never more workers than images, so
        # every chunk is non-empty and the chunks exactly tile [0, N).
        workers = min(multiprocessing.cpu_count(), self.imgs_count)
        bounds = [i * self.imgs_count // workers for i in range(workers + 1)]

        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(self.run_process, [bounds[i], bounds[i + 1]])
                for i in range(workers)
            ]
            # result() re-raises any worker exception instead of silently
            # leaving the rest of that chunk zero-filled.
            for future in futures:
                future.result()
        return self.images, self.labels

    @staticmethod
    def _resize_and_center_crop(img, size):
        """Scale *img* so its shorter side is *size*, then center-crop to size x size."""
        width, height = img.size
        if width == height:
            return img.resize((size, size), Image.BICUBIC)
        if width > height:
            new_width = math.floor((width / height) * size)
            img = img.resize((new_width, size), Image.BICUBIC)
            left = (new_width - size) // 2
            return img.crop((left, 0, left + size, size))
        new_height = math.floor((height / width) * size)
        img = img.resize((size, new_height), Image.BICUBIC)
        # Portrait: crop vertically around the center.
        top = (new_height - size) // 2
        return img.crop((0, top, size, top + size))

    def run_process(self, start_stop):
        """Fill rows ``start_stop[0]`` .. ``start_stop[1] - 1`` of the output arrays.

        Args:
            start_stop: indexable whose first two items are the start
                (inclusive) and stop (exclusive) indices into ``self.files``.
        """
        for i in range(start_stop[0], start_stop[1]):
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded image.
            with Image.open(os.path.join(self.path, self.files[i])) as img:
                label = img.convert("RGB")
            # NOTE: EXIF orientation is not applied.
            label = self._resize_and_center_crop(label, self.out_ress)
            # PIL arrays are HWC; the output arrays store CHW.
            self.labels[i] = numpy.moveaxis(numpy.array(label), -1, 0)
            small = label.resize((self.in_ress, self.in_ress), Image.BICUBIC)
            self.images[i] = numpy.moveaxis(numpy.array(small), -1, 0)


class Process_dataset:
    """Hold training / testing / validation image pairs and serve random batches.

    Args:
        in_ress: side length of input images.
        out_ress: side length of label images.
        training_path: folder of training images.
        aug_count: number of randomly rotated copies to create per training
            image (0 disables all training-time augmentation).
        testing_path: optional folder of testing images (False to skip).
        validate_path: optional folder of validation images (False to skip).

    Raises:
        Exception: if ``aug_count`` is negative.
    """

    def __init__(self, in_ress, out_ress, training_path, aug_count, testing_path= False, validate_path= False):
        self.in_ress = in_ress
        self.out_ress = out_ress
        self.training_path = training_path
        self.testing_path = testing_path
        self.validate_path = validate_path
        self.aug_count = aug_count

        # Each of these is a list of "groups" (arrays of images / labels);
        # batches pick a group first, then an image within it.
        self.training_images = []
        self.training_labels = []
        self.validate_images = []
        self.validate_labels = []
        self.testing_images = []
        self.testing_labels = []

        if self.aug_count < 0:
            raise Exception("The number of augments aug_count can not be negative")

        train_loader = ImagesLoader(self.training_path, self.in_ress, self.out_ress)
        training_images_raw, training_labels_raw = train_loader.load_images()

        if self.testing_path:
            test_loader = ImagesLoader(self.testing_path, self.in_ress, self.out_ress)
            testing_images, testing_labels = test_loader.load_images()
            self.testing_images.append(testing_images)
            self.testing_labels.append(testing_labels)
            self.testing_count = len(testing_images)
            print("Testing images count - {}".format(self.testing_count))

        if self.validate_path:
            validate_loader = ImagesLoader(self.validate_path, self.in_ress, self.out_ress)
            validate_images, validate_labels = validate_loader.load_images()
            self.validate_images.append(validate_images)
            self.validate_labels.append(validate_labels)
            self.validate_count = len(validate_images)
            print("Validate images count - {}".format(self.validate_count))

        # Group 0: raw training data.
        self.training_images.append(training_images_raw)
        self.training_labels.append(training_labels_raw)

        if self.aug_count != 0:
            # Group 1: a second reference to the raw data; process_batch gives
            # this group flips only (no brightness/contrast noise).
            self.training_images.append(training_images_raw)
            self.training_labels.append(training_labels_raw)
            # Group 2: aug_count rotated copies of every raw pair.
            images_aug, labels_aug = self._auqumentation(self.training_images, self.training_labels, self.aug_count)
            self.training_images.append(images_aug)
            self.training_labels.append(labels_aug)
            self.training_count = len(training_images_raw) * 2 + len(images_aug)
        else:
            self.training_count = len(training_images_raw)
        print("Training images count - {}".format(self.training_count))
        print("\n\n")

    def get_training_batch(self, batch_size):
        """Return a random augmented training batch ``(x, y)`` of float tensors."""
        return self.get_batch(self.training_images, self.training_labels, batch_size)

    def get_training_count(self):
        """Return the number of training samples across all groups."""
        return self.training_count

    def get_testing_batch(self, batch_size):
        """Return a random, non-augmented testing batch.

        Raises:
            Exception: if no testing path was given.
        """
        if self.testing_path:
            return self.get_batch(self.testing_images, self.testing_labels, batch_size, training=False)
        raise Exception("No testing data... Please specify testing path folder in Process_dataset class")

    def get_testing_count(self):
        """Return the number of testing images.

        Raises:
            Exception: if no testing path was given.
        """
        if self.testing_path:
            return self.testing_count
        raise Exception("No testing data... Please specify testing path folder in Process_dataset class")

    def get_validate_batch(self, batch_size):
        """Return a random, non-augmented validation batch.

        Raises:
            Exception: if no validation path was given.
        """
        if self.validate_path:
            return self.get_batch(self.validate_images, self.validate_labels, batch_size, training=False)
        raise Exception("No validate data... Please specify validate path folder in Process_dataset class")

    def get_validate_count(self):
        """Return the number of validation images.

        Raises:
            Exception: if no validation path was given.
        """
        if self.validate_path:
            return self.validate_count
        raise Exception("No validate data... Please specify validate path folder in Process_dataset class")

    def process_augumentation(self, image, label):
        """Rotate a CHW uint8 (image, label) pair by the same random angle.

        The angle is drawn uniformly from [-25, 25) degrees.

        Returns:
            ``(image_aug, mask_aug)`` as uint8 CHW arrays.
        """
        angle_max = 25
        angle = self._rnd(-angle_max, angle_max)

        # CHW -> HWC for PIL.
        image_in = Image.fromarray(numpy.moveaxis(numpy.asarray(image, dtype=numpy.uint8), 0, 2), 'RGB')
        mask_in = Image.fromarray(numpy.moveaxis(numpy.asarray(label, dtype=numpy.uint8), 0, 2), 'RGB')

        image_aug = numpy.array(image_in.rotate(angle, resample=Image.BICUBIC))
        mask_aug = numpy.array(mask_in.rotate(angle, resample=Image.BICUBIC))

        # HWC -> CHW: the exact inverse of the moveaxis above. (swapaxes(0, 2)
        # would give CWH, i.e. a transposed image.)
        return numpy.moveaxis(image_aug, -1, 0), numpy.moveaxis(mask_aug, -1, 0)

    def _auqumentation(self, images, labels, aug_count):
        """Build ``aug_count`` rotated copies of every pair in group 0.

        Returns:
            ``(images_aug, labels_aug)`` uint8 arrays with
            ``len(images[0]) * aug_count`` rows, in source-image order
            repeated ``aug_count`` times.
        """
        count = len(images[0])
        total = count * aug_count
        images_aug = numpy.zeros((total, 3, self.in_ress, self.in_ress), dtype=numpy.uint8)
        labels_aug = numpy.zeros((total, 3, self.out_ress, self.out_ress), dtype=numpy.uint8)
        print(images_aug.shape)

        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            futures = [
                executor.submit(self.process_augumentation, images[0][k % count], labels[0][k % count])
                for k in range(total)
            ]
            # Collect in submission order so the output layout is deterministic.
            for k, future in enumerate(futures):
                images_aug[k], labels_aug[k] = future.result()
        return images_aug, labels_aug

    def get_batch(self, images, labels, batch_size, training = True):
        """Assemble ``batch_size`` random samples into float tensors.

        Returns:
            ``(x, y)`` with shapes (B, 3, in_ress, in_ress) and
            (B, 3, out_ress, out_ress), values scaled to [0, 1].
        """
        result_x = torch.zeros((batch_size, 3, self.in_ress, self.in_ress)).float()
        result_y = torch.zeros((batch_size, 3, self.out_ress, self.out_ress)).float()
        # max(1, ...) keeps ThreadPoolExecutor valid for an empty batch.
        with ThreadPoolExecutor(max_workers=max(1, batch_size)) as executor:
            futures = [
                executor.submit(self.process_batch, images, labels, training)
                for _ in range(batch_size)
            ]
            for k, future in enumerate(futures):
                result_x[k], result_y[k] = future.result()
        return result_x, result_y

    def process_batch(self, images, labels, training):
        """Pick one random sample, scale it to [0, 1] and optionally augment it.

        Augmentation runs only when ``training`` is true and ``aug_count`` is
        non-zero: every sample is flipped, and samples not from group 1 also
        get brightness/contrast noise.
        """
        group_idx = numpy.random.randint(len(images))
        image_idx = numpy.random.randint(len(images[group_idx]))
        image_np = numpy.array(images[group_idx][image_idx]) / 255.0
        label_np = numpy.array(labels[group_idx][image_idx]) / 255.0

        if training and self.aug_count != 0:
            image_np, label_np = self._augmentation_flip(image_np, label_np)
            if group_idx != 1:
                image_np, label_np = self._augmentation_noise(image_np, label_np)

        return torch.from_numpy(image_np).float(), torch.from_numpy(label_np).float()

    def _augmentation_flip(self, image_np, label_np):
        """Flip both CHW arrays along axis 1 or axis 2 (50/50), returning copies."""
        axis = 1 if self._rnd(0, 1) < 0.5 else 2
        # copy() makes the arrays contiguous for torch.from_numpy.
        return numpy.flip(image_np, axis).copy(), numpy.flip(label_np, axis).copy()

    def _augmentation_noise(self, image_np, label_np):
        """Apply one random brightness shift and contrast scale to both arrays, clipped to [0, 1]."""
        brightness = self._rnd(-0.30, 0.25)
        contrast = self._rnd(0.6, 1.1)
        img_result = 0.5 + contrast * ((image_np + brightness) - 0.5)
        label_result = 0.5 + contrast * ((label_np + brightness) - 0.5)
        return numpy.clip(img_result, 0.0, 1.0), numpy.clip(label_result, 0.0, 1.0)

    def _rnd(self, min_value, max_value):
        """Return a uniform random float in [min_value, max_value)."""
        return (max_value - min_value) * numpy.random.rand() + min_value
Editor is loading...