diff --git a/qi_lib/core/record/components.py b/qi_lib/core/record/components.py
index df65a68..1afe61a 100644
--- a/qi_lib/core/record/components.py
+++ b/qi_lib/core/record/components.py
@@ -2,7 +2,7 @@ __all__ = ["Component", "ImagePathComponent", "SizeComponent", "BBoxesComponent"
            "ClassMapComponent", "ImageComponent", "TransformationsComponent"]
 
 import copy
-from abc import ABC
+from abc import ABC, abstractmethod
 from collections import namedtuple
 from logging import Logger
 from pathlib import Path
@@ -310,28 +310,63 @@ class ScoresComponent(Component):
         return f"{self.__class__.__name__}; scores:{self.scores}"
 
 
+class Transform(ABC):
+    # Wraps a callable; subclasses decide whether it is applied per record
+    # or per batch.
+    def __init__(self, transform_fn) -> None:
+        self.transform_fn = transform_fn
+
+    @abstractmethod
+    def __call__(self, *args: Any, **kwds: Any) -> Composite:
+        ...
+
+
+class SingleTransform(Transform):
+    # Overrides __call__ to apply the wrapped function to a single record.
+    def __call__(self, record: Composite) -> Composite:
+        return self.transform_fn(record)
+
+
+class BatchTransform(Transform):
+    # Overrides __call__ to apply the wrapped function to a whole batch.
+    def __call__(self, records: List[Composite]) -> List[Composite]:
+        return self.transform_fn(records)
+
+
 class TransformationsComponent(Component):
     # TODO add torchvision support
-    transformations: Any
+    transformations: list[Transform]
 
-    def __init__(self, transformations):
+    def __init__(self, transformations: list[Transform]):
         super().__init__()
-        self.transformations = transformations
+        self.transformations = list(transformations)
         self.composite = None
 
     def apply_transformations(self):
         # TODO add support for object detection and segmentation
         transformed = self.composite.transformations(image=self.composite.image)
         return {'image': transformed['image'], 'labels': self.composite.labels}
 
+    def transform_batch(self, records: List[Composite]) -> List[Composite]:
+        for transform in self.transformations:
+            if isinstance(transform, SingleTransform):
+                records = [transform(record) for record in records]
+            elif isinstance(transform, BatchTransform):
+                records = transform(records)
+            else:
+                raise TypeError(f"Unknown transformation type {type(transform)}")
+        return records
+
     def transform_record(self):
         # if isinstance(transformations, albumentations):
         # else: torchvision
-
-        record = copy.deepcopy(self.composite)
-        record.image = record.transformations(image=record.composite.image)['image']
+        record = self.composite
+        for transform in self.transformations:
+            assert isinstance(transform, SingleTransform)
+            record = transform(record)
         return record
 
     def __repr__(self):
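A minimal usage sketch of the two new wrapper types (the flip_record and pad_batch functions below are hypothetical, and a Composite record is assumed to expose a mutable image attribute):

# Hypothetical transform functions, for illustration only.
def flip_record(record):
    record.image = record.image[:, ::-1]  # horizontal flip, HWC layout
    return record

def pad_batch(records):
    # e.g. pad every image to the largest spatial size in the batch
    return records

transforms = [SingleTransform(flip_record), BatchTransform(pad_batch)]
component = TransformationsComponent(transforms)
# transform_batch then applies each wrapper in order: per record for
# SingleTransform, on the whole list for BatchTransform.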
diff --git a/qi_lib/datamodules/base_datamodule.py b/qi_lib/datamodules/base_datamodule.py
index a912342..72e401c 100644
--- a/qi_lib/datamodules/base_datamodule.py
+++ b/qi_lib/datamodules/base_datamodule.py
@@ -4,6 +4,7 @@
 from pytorch_lightning import LightningDataModule
 from torch import Generator
 from torch.utils.data import DataLoader, Dataset  # random_split
 from qi_lib.core.data.utils import random_split
+import functools
 from qi_lib.core.data.placeholders import NotImplementedDataset
 from qi_lib.core.data.types import DatasetType
@@ -28,6 +29,14 @@ class AddTransformsDataset(Dataset):
         return record
 
 
+def transform_collate_fn(records: list[Composite], model_build_batch_fn):
+    # First apply the transforms to the batch, then call the model's
+    # build-batch function. The transforms are assumed to be the same for
+    # all records in a batch, so they are taken from the first record.
+    transforms = records[0].get_component()
+    records = transforms.transform_batch(records)
+    return records, model_build_batch_fn(records)
+
+
 class BaseDataModule(LightningDataModule):
     def __init__(self,
                  datasets: Union[list[Dataset], Dataset],
                  batch_size: int,
                  transforms=(None, None),
                  model=None,
                  split_ratio: Tuple[float, float, float] = (0.8, 0.2, 0.0),
                  num_workers: int = 8,
@@ -90,12 +99,12 @@ class BaseDataModule(LightningDataModule):
     def train_dataloader(self) -> DataLoader:
         # TODO add an option to override model
         return DataLoader(self.train_ds, self.hparams.batch_size, shuffle=True, num_workers=self.hparams.num_workers,
-                          collate_fn=self.model.build_train_batch, drop_last=self.hparams.drop_last)
+                          collate_fn=functools.partial(transform_collate_fn, model_build_batch_fn=self.model.build_train_batch),
+                          drop_last=self.hparams.drop_last)
 
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.val_ds, self.hparams.batch_size, shuffle=False, num_workers=self.hparams.num_workers,
-                          collate_fn=self.model.build_val_batch, drop_last=self.hparams.drop_last)
+                          collate_fn=functools.partial(transform_collate_fn, model_build_batch_fn=self.model.build_val_batch),
+                          drop_last=self.hparams.drop_last)
 
     def test_dataloader(self) -> DataLoader:
         return DataLoader(self.test_ds, self.hparams.batch_size, shuffle=True, num_workers=self.hparams.num_workers,
-                          collate_fn=self.model.build_val_batch, drop_last=self.hparams.drop_last)
+                          collate_fn=functools.partial(transform_collate_fn, model_build_batch_fn=self.model.build_val_batch),
+                          drop_last=self.hparams.drop_last)
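A sketch of why functools.partial is used above: DataLoader invokes collate_fn with a single argument (the list of records), so the model's build-batch function has to be pre-bound as a keyword argument. Assuming a model exposing the build_train_batch static method from this diff:

collate = functools.partial(transform_collate_fn,
                            model_build_batch_fn=model.build_train_batch)
# The DataLoader calls collate(records); transform_collate_fn returns
# (records, model_batch), so every batch unpacks as:
records, (xb, yb) = collate(records)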
diff --git a/qi_lib/models/classification/timm_wrapper.py b/qi_lib/models/classification/timm_wrapper.py
index 6417da5..5921c89 100644
--- a/qi_lib/models/classification/timm_wrapper.py
+++ b/qi_lib/models/classification/timm_wrapper.py
@@ -24,18 +24,13 @@ class TimmWrapper(nn.Module):
-        data = {}
         # TODO this is hacky AF - do something about this :).
         for record in records:
-            if hasattr(record, 'transform_record'):
-                data = record.apply_transformations()
-                # image, label = TimmWrapper.to_model_format(record.transform_record())
-            else:
-                image, label = TimmWrapper.to_model_format(record)
-            images.append(data['image'])
-            labels.append(torch.tensor(data['labels']))
+            # Records arrive already transformed by transform_collate_fn,
+            # so the image and labels can be read straight off the record.
+            images.append(torch.from_numpy(record.image / 255.0).permute(2, 0, 1))
+            labels.append(torch.tensor(record.labels))
         return torch.stack(images, 0), torch.cat(labels, 0)
 
     @staticmethod
     def build_val_batch(records: list[Composite]):
-        return TimmWrapper.build_train_batch(records), records
+        return TimmWrapper.build_train_batch(records)
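A note on the image conversion above, assuming HWC uint8 numpy images: dividing by 255.0 rescales to [0, 1] but also promotes to float64 (so a .float() cast before stacking may still be wanted), and permute(2, 0, 1) yields the channels-first layout that timm models expect.

import numpy as np
import torch

img = np.zeros((224, 224, 3), dtype=np.uint8)       # HWC uint8 image
t = torch.from_numpy(img / 255.0).permute(2, 0, 1)  # CHW, float64 in [0, 1]
assert t.shape == (3, 224, 224)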
diff --git a/qi_lib/modules/classification_module.py b/qi_lib/modules/classification_module.py
index 16d1aeb..ce0a628 100644
--- a/qi_lib/modules/classification_module.py
+++ b/qi_lib/modules/classification_module.py
@@ -20,19 +20,25 @@ class LightningClassification(LightningModule):
         return self.model(x)
 
     def training_step(self, batch, batch_idx):
-        xb, yb = batch
+        records, (xb, yb) = batch
         pred = self.model(xb)
         loss = self.loss_function(pred, yb)
 
         self.train_accuracy(pred, yb)
         self.log('train/loss', loss, on_step=True, on_epoch=False)
+
+        pred = self.model.predictions_from_model_format(pred)
+        # TODO: log predictions
+        # self.log_predictions(pred, yb, records, batch_idx, 'train')
+
         return loss
 
     def training_epoch_end(self, outs):
         self.log('train/acc', self.train_accuracy)
 
     def validation_step(self, batch, batch_idx):
-        (xb, yb), records = batch
+        records, (xb, yb) = batch
         pred = self.model(xb)
         loss = self.loss_function(pred, yb)
@@ -41,6 +47,10 @@ class LightningClassification(LightningModule):
         self.log('val/loss', loss, on_step=False, on_epoch=True)
         # TODO add visualization and logging
+        pred = self.model.predictions_from_model_format(pred)
+        # TODO: log predictions
+        # self.log_predictions(pred, yb, records, batch_idx, 'val')
 
     def validation_epoch_end(self, outs):
         self.log('val/acc', self.val_accuracy)
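With build_val_batch no longer appending records itself, every loader now yields the shape produced by transform_collate_fn, so the training and validation steps can unpack batches identically:

# Uniform batch layout from transform_collate_fn:
records, (xb, yb) = batch  # records keep metadata; xb/yb feed model and loss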