diff --git a/qi_lib/core/record/components.py b/qi_lib/core/record/components.py
index df65a68..1afe61a 100644
--- a/qi_lib/core/record/components.py
+++ b/qi_lib/core/record/components.py
@@ -2,7 +2,7 @@ __all__ = ["Component", "ImagePathComponent", "SizeComponent", "BBoxesComponent"
            "ClassMapComponent", "ImageComponent", "TransformationsComponent"]
 
 import copy
-from abc import ABC
+from abc import ABC, abstractmethod
 from collections import namedtuple
 from logging import Logger
 from pathlib import Path
@@ -310,28 +310,63 @@ class ScoresComponent(Component):
         return f"{self.__class__.__name__}; scores:{self.scores}"
 
 
+class Transform(ABC):
+    # Wraps a callable so the pipeline can distinguish per-record from
+    # batch-level transforms.
+    def __init__(self, transform_fn) -> None:
+        self.transform_fn = transform_fn
+
+    @abstractmethod
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        # Returns a Composite (SingleTransform) or a list[Composite] (BatchTransform).
+        ...
+
+
+class SingleTransform(Transform):
+    # Applies the wrapped function to one record at a time.
+    def __call__(self, record: Composite) -> Composite:
+        return self.transform_fn(record)
+
+
+class BatchTransform(Transform):
+    # Applies the wrapped function to the whole batch at once.
+    def __call__(self, records: List[Composite]) -> List[Composite]:
+        return self.transform_fn(records)
+
+
 class TransformationsComponent(Component):
     # TODO add torchvision support
-    transformations: Any
+    transformations: list[Transform]
 
-    def __init__(self, transformations):
+    def __init__(self, transformations: list[Transform]):
         super().__init__()
-        self.transformations = transformations
+        self.transformations = transformations
         self.composite = None
 
     def apply_transformations(self):
         # TODO add support for object detection and segmentation
         transformed = self.composite.transformations(image=self.composite.image)
         return {'image': transformed['image'], 'labels' : self.composite.labels}
-
+    
+
+    def transform_batch(self, records: List[Composite]) -> List[Composite]:
+        for transform in self.transformations:
+            if isinstance(transform, SingleTransform):
+                records = [transform(record) for record in records]
+            elif isinstance(transform, BatchTransform):
+                records = transform(records)
+            else:
+                raise TypeError(f"Unknown transformation type {type(transform)}")
+        return records
 
     def transform_record(self):
         # if isinstance(transformations , albumentations):
         #else
         # torchvision
-
-        record = copy.deepcopy(self.composite)
-        record.image = record.transformations(image=record.composite.image)['image']
+        record = self.composite
+        for transform in self.transformations:
+            assert isinstance(transform, SingleTransform)
+            record = transform(record)
         return record
 
     def __repr__(self):
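
A minimal usage sketch of the new transform wrappers (not part of the diff; hflip_record and mixup_batch are hypothetical callables, while SingleTransform, BatchTransform, TransformationsComponent, and transform_batch are the classes and method added above):

def hflip_record(record):
    # Per-record transform: Composite -> Composite.
    record.image = record.image[:, ::-1]
    return record

def mixup_batch(records):
    # Batch-level transform: list[Composite] -> list[Composite], e.g. pair
    # records up and blend their images; identity here for brevity.
    return records

component = TransformationsComponent([
    SingleTransform(hflip_record),  # applied to every record in turn
    BatchTransform(mixup_batch),    # applied once to the whole batch
])
transformed = component.transform_batch(records)  # records: list[Composite]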
diff --git a/qi_lib/datamodules/base_datamodule.py b/qi_lib/datamodules/base_datamodule.py
index a912342..72e401c 100644
--- a/qi_lib/datamodules/base_datamodule.py
+++ b/qi_lib/datamodules/base_datamodule.py
@@ -4,6 +4,7 @@ from pytorch_lightning import LightningDataModule
 from torch import Generator
 from torch.utils.data import DataLoader, Dataset# random_split
 from qi_lib.core.data.utils import random_split
+import functools
 
 from qi_lib.core.data.placeholders import NotImplementedDataset
 from qi_lib.core.data.types import DatasetType
@@ -28,6 +29,14 @@ class AddTransformsDataset(Dataset):
         return record
 
 
+def transform_collate_fn(records: list[Composite], model_build_batch_fn):
+    # First apply transforms to batch and then call model's build batch function
+
+    # transforms should be the same for all records in a batch
+    transforms = records[0].get_component()
+    records = transforms.transform_batch(records)
+    return records, model_build_batch_fn(records)
+
 class BaseDataModule(LightningDataModule):
     def __init__(self, datasets: Union[list[Dataset], Dataset], batch_size: int, transforms=(None, None), model=None,
                  split_ratio: Tuple[float, float, float] = (0.8, 0.2, 0.0), num_workers: int = 8,
@@ -90,12 +99,12 @@ class BaseDataModule(LightningDataModule):
     def train_dataloader(self) -> DataLoader:
         # TODO add an option to override model
         return DataLoader(self.train_ds, self.hparams.batch_size, shuffle=True, num_workers=self.hparams.num_workers,
-                          collate_fn=self.model.build_train_batch, drop_last=self.hparams.drop_last)
+                          collate_fn=functools.partial(transform_collate_fn, model_build_batch_fn=self.model.build_train_batch), drop_last=self.hparams.drop_last)
 
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.val_ds, self.hparams.batch_size, shuffle=False, num_workers=self.hparams.num_workers,
-                          collate_fn=self.model.build_val_batch, drop_last=self.hparams.drop_last)
+                          collate_fn=functools.partial(transform_collate_fn, model_build_batch_fn=self.model.build_val_batch), drop_last=self.hparams.drop_last)
 
     def test_dataloader(self) -> DataLoader:
         return DataLoader(self.test_ds, self.hparams.batch_size, shuffle=True, num_workers=self.hparams.num_workers,
-                          collate_fn=self.model.build_val_batch, drop_last=self.hparams.drop_last)
+                          collate_fn=functools.partial(transform_collate_fn, model_build_batch_fn=self.model.build_val_batch), drop_last=self.hparams.drop_last)
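
For reference, a sketch of what each loader's collate call now does once functools.partial binds the model's batch builder (assumes a model instance and a list of records as in BaseDataModule; transform_collate_fn and the build_*_batch methods are the ones from the diff):

import functools

collate = functools.partial(transform_collate_fn,
                            model_build_batch_fn=model.build_train_batch)
# The DataLoader hands `collate` one batch worth of records:
records, (xb, yb) = collate(records)
# records: transformed Composites; (xb, yb): tensors from build_train_batch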
diff --git a/qi_lib/models/classification/timm_wrapper.py b/qi_lib/models/classification/timm_wrapper.py
index 6417da5..5921c89 100644
--- a/qi_lib/models/classification/timm_wrapper.py
+++ b/qi_lib/models/classification/timm_wrapper.py
@@ -24,18 +24,13 @@ class TimmWrapper(nn.Module):
-        data = {}
         # TODO this is hacky AF - do something about this :).
         for record in records:
-            if hasattr(record, 'transform_record'):
-                data = record.apply_transformations()
-                # image, label = TimmWrapper.to_model_format(record.transform_record())
-            else:
-                image, label = TimmWrapper.to_model_format(record)
-            images.append(data['image'])
-            labels.append(torch.tensor(data['labels']))
+            # Records arrive already transformed by the collate fn, so read the
+            # image/labels attributes directly; the old `data` dict was never
+            # populated after this refactor and would have raised a KeyError.
+            images.append(torch.from_numpy(record.image / 255.0).permute(2, 0, 1))
+            labels.append(torch.tensor(record.labels))
         return torch.stack(images, 0), torch.cat(labels, 0)
 
     @staticmethod
     def build_val_batch(records: list[Composite]):
-        return TimmWrapper.build_train_batch(records), records
+        return TimmWrapper.build_train_batch(records)
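
A quick shape walkthrough of the tensor conversion inside build_train_batch (a standalone sketch with a dummy image; real images come from the transformed records):

import numpy as np
import torch

image = np.zeros((224, 224, 3), dtype=np.uint8)            # HWC uint8
tensor = torch.from_numpy(image / 255.0).permute(2, 0, 1)  # CHW, float64
batch = torch.stack([tensor, tensor], 0)                   # NCHW
print(batch.shape)  # torch.Size([2, 3, 224, 224])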
 
 
 
diff --git a/qi_lib/modules/classification_module.py b/qi_lib/modules/classification_module.py
index 16d1aeb..ce0a628 100644
--- a/qi_lib/modules/classification_module.py
+++ b/qi_lib/modules/classification_module.py
@@ -20,19 +20,25 @@ class LightningClassification(LightningModule):
         return self.model(x)
 
     def training_step(self, batch, batch_idx):
-        xb, yb = batch
+        records, (xb, yb) = batch
         pred = self.model(xb)
 
         loss = self.loss_function(pred, yb)
         self.train_accuracy(pred, yb)
         self.log('train/loss', loss, on_step=True, on_epoch=False)
+
+        pred = self.model.predictions_from_model_format(pred)
+        # TODO: log predictions
+        # self.log_predictions(pred, yb, records, batch_idx, 'train')
+
         return loss
 
     def training_epoch_end(self, outs):
         self.log('train/acc', self.train_accuracy)
 
     def validation_step(self, batch, batch_idx):
-        (xb, yb), records = batch
+        records, (xb, yb) = batch
         pred = self.model(xb)
 
         loss = self.loss_function(pred, yb)
@@ -41,6 +47,10 @@ class LightningClassification(LightningModule):
         self.log('val/loss', loss, on_step=False, on_epoch=True)
         #TODO add visualization and logging
 
+        pred = self.model.predictions_from_model_format(pred)
+        # TODO: log predictions
+        # self.log_predictions(pred, yb, records, batch_idx, 'val')
+
     def validation_epoch_end(self, outs):
         self.log('val/acc', self.val_accuracy)
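
Putting the two changed steps together, the shared convention is (a sketch assembled from the lines added above; nothing new is introduced):

records, (xb, yb) = batch            # layout produced by transform_collate_fn
pred = self.model(xb)                # model-format logits
pred = self.model.predictions_from_model_format(pred)  # per-record predictions
# `records` is kept alongside the tensors so predictions can later be logged
# against the original records (see the log_predictions TODOs above).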