Untitled
unknown
diff
3 years ago
8.2 kB
7
Indexable
diff --git a/qi_lib/core/record/components.py b/qi_lib/core/record/components.py
index df65a68..1afe61a 100644
--- a/qi_lib/core/record/components.py
+++ b/qi_lib/core/record/components.py
@@ -2,7 +2,7 @@ __all__ = ["Component", "ImagePathComponent", "SizeComponent", "BBoxesComponent"
"ClassMapComponent", "ImageComponent", "TransformationsComponent"]
import copy
-from abc import ABC
+from abc import ABC, abstractmethod
from collections import namedtuple
from logging import Logger
from pathlib import Path
@@ -310,28 +310,63 @@ class ScoresComponent(Component):
return f"{self.__class__.__name__}; scores:{self.scores}"
+
+class Transform(ABC):
+ def __init__(self, transform_fn) -> None:
+ self.transform_fn = transform_fn
+
+
+ @abstractmethod
+ def __call__(self, *args: Any, **kwds: Any) -> Composite:
+ ...
+
+
+class SingleTransform(Transform):
+
+ # override the call method
+ def __call__(self, record: Composite) -> Composite:
+ return self.transform_fn(record)
+
+class BatchTransform(Transform):
+
+ # override the call method
+ def __call__(self, records: List[Composite]) -> List[Composite]:
+ return self.transform_fn(records)
+
+
class TransformationsComponent(Component):
# TODO add torchvision suport
- transformations: Any
+ transformations: list[Transform]
- def __init__(self, transformations):
+ def __init__(self, transformations: list[Transform]):
super().__init__()
- self.transformations = transformations
+ self.transformations = list[Transform]
self.composite = None
def apply_transformations(self):
# TODO add support for object detection and segmentation
transformed = self.composite.transformations(image=self.composite.image)
return {'image': transformed['image'], 'labels' : self.composite.labels}
-
+
+
+ def transform_batch(self, records: List[Composite]) -> List[Composite]:
+ for transform in self.transformations:
+ if isinstance(transform, SingleTransform):
+ records = [transform(record) for record in records]
+ elif isinstance(transform, BatchTransform):
+ records = transform(records)
+ else:
+ raise TypeError(f"Unknown transformation type {type(transform)}")
+ return records
def transform_record(self):
# if isinstance(transformations , albumentations):
#else
# torchvision
-
- record = copy.deepcopy(self.composite)
- record.image = record.transformations(image=record.composite.image)['image']
+ record = self.composite
+ for transform in self.transformations:
+ assert isinstance(transform, SingleTransform)
+ record = transform(record)
return record
def __repr__(self):
diff --git a/qi_lib/datamodules/base_datamodule.py b/qi_lib/datamodules/base_datamodule.py
index a912342..72e401c 100644
--- a/qi_lib/datamodules/base_datamodule.py
+++ b/qi_lib/datamodules/base_datamodule.py
@@ -4,6 +4,7 @@ from pytorch_lightning import LightningDataModule
from torch import Generator
from torch.utils.data import DataLoader, Dataset# random_split
from qi_lib.core.data.utils import random_split
+import functools
from qi_lib.core.data.placeholders import NotImplementedDataset
from qi_lib.core.data.types import DatasetType
@@ -28,6 +29,14 @@ class AddTransformsDataset(Dataset):
return record
+def transform_collate_fn(records: list[Composite], model_build_batch_fn):
+ # First apply transforms to batch and then call model's build batch function
+
+ # transforms should be the same for all records in a batch
+ transforms = records[0].get_component()
+ records = transforms.transform_batch(records)
+ return records, model_build_batch_fn(records)
+
class BaseDataModule(LightningDataModule):
def __init__(self, datasets: Union[list[Dataset], Dataset], batch_size: int, transforms=(None, None), model=None,
split_ratio: Tuple[float, float, float] = (0.8, 0.2, 0.0), num_workers: int = 8,
@@ -90,12 +99,12 @@ class BaseDataModule(LightningDataModule):
def train_dataloader(self) -> DataLoader:
# TODO add an option to override model
return DataLoader(self.train_ds, self.hparams.batch_size, shuffle=True, num_workers=self.hparams.num_workers,
- collate_fn=self.model.build_train_batch, drop_last=self.hparams.drop_last)
+ collate_fn=functools.partial(transform_collate_fn, model_build_batch_fn=self.model.build_train_batch), drop_last=self.hparams.drop_last)
def val_dataloader(self) -> DataLoader:
return DataLoader(self.val_ds, self.hparams.batch_size, shuffle=False, num_workers=self.hparams.num_workers,
- collate_fn=self.model.build_val_batch, drop_last=self.hparams.drop_last)
+ collate_fn=functools.partial(transform_collate_fn, model_build_batch_fn=self.model.build_val_batch), drop_last=self.hparams.drop_last)
def test_dataloader(self) -> DataLoader:
return DataLoader(self.test_ds, self.hparams.batch_size, shuffle=True, num_workers=self.hparams.num_workers,
- collate_fn=self.model.build_val_batch, drop_last=self.hparams.drop_last)
+ collate_fn=functools.partial(transform_collate_fn, model_build_batch_fn=self.model.build_val_batch), drop_last=self.hparams.drop_last)
diff --git a/qi_lib/models/classification/timm_wrapper.py b/qi_lib/models/classification/timm_wrapper.py
index 6417da5..5921c89 100644
--- a/qi_lib/models/classification/timm_wrapper.py
+++ b/qi_lib/models/classification/timm_wrapper.py
@@ -24,18 +24,13 @@ class TimmWrapper(nn.Module):
data = {}
# TODO this is hacky AF - do something about this :).
for record in records:
- if hasattr(record, 'transform_record'):
- data = record.apply_transformations()
- # image, label = TimmWrapper.to_model_format(record.transform_record())
- else:
- image, label = TimmWrapper.to_model_format(record)
- images.append(data['image'])
+ images.append(torch.from_numpy(data['image'] / 255.0).permute(2, 0, 1))
labels.append(torch.tensor(data['labels']))
return torch.stack(images, 0), torch.cat(labels, 0)
@staticmethod
def build_val_batch(records: list[Composite]):
- return TimmWrapper.build_train_batch(records), records
+ return TimmWrapper.build_train_batch(records)
diff --git a/qi_lib/modules/classification_module.py b/qi_lib/modules/classification_module.py
index 16d1aeb..ce0a628 100644
--- a/qi_lib/modules/classification_module.py
+++ b/qi_lib/modules/classification_module.py
@@ -20,19 +20,25 @@ class LightningClassification(LightningModule):
return self.model(x)
def training_step(self, batch, batch_idx):
- xb, yb = batch
+ records, (xb, yb) = batch
pred = self.model(xb)
loss = self.loss_function(pred, yb)
self.train_accuracy(pred, yb)
self.log('train/loss', loss, on_step=True, on_epoch=False)
+
+
+ pred = self.model.predictions_from_model_format(pred)
+ # TODO: log predictions
+ # self.log_predictions(pred, yb, records, batch_idx, 'train')
+
return loss
def training_epoch_end(self, outs):
self.log('train/acc', self.train_accuracy)
def validation_step(self, batch, batch_idx):
- (xb, yb), records = batch
+ records, (xb, yb) = batch
pred = self.model(xb)
loss = self.loss_function(pred, yb)
@@ -41,6 +47,10 @@ class LightningClassification(LightningModule):
self.log('val/loss', loss, on_step=False, on_epoch=True)
#TODO add visualization and logging
+ pred = self.model.predictions_from_model_format(pred)
+ # TODO: log predictions
+ # self.log_predictions(pred, yb, records, batch_idx, 'val')
+
def validation_epoch_end(self, outs):
self.log('val/acc', self.val_accuracy)
Editor is loading...