Untitled
unknown
python
7 months ago
4.3 kB
3
Indexable
Never
from tsl.engines import Predictor
from tsl.data import Data


class CustomPredictor(Predictor):
    """Predictor that also accounts for an auxiliary loss (e.g. from MinCutPool).

    The wrapped model's forward pass must return a ``(y_hat, aux_loss)`` pair.
    The auxiliary term is added to the task loss in the training, validation
    and test steps. It is dropped by ``predict_step`` and ``compute_metrics``.
    """

    def predict_batch(self, batch: Data,
                      preprocess: bool = False, postprocess: bool = True,
                      return_target: bool = False,
                      **forward_kwargs):
        """Run the model on ``batch``.

        Args:
            batch: The batch to predict on.
            preprocess: If True, apply each transform in the batch to the input
                with the same key before the forward pass.
            postprocess: If True and the batch has a ``'y'`` transform, apply
                its inverse to the predictions.
            return_target: If True, also return the target and the mask.
            **forward_kwargs: Extra keyword arguments forwarded to
                ``self.forward``.

        Returns:
            ``(y_hat, aux_loss)``, or ``(y, y_hat, mask, aux_loss)`` when
            ``return_target`` is True.
        """
        inputs, targets, mask, transform = self._unpack_batch(batch)
        if preprocess:
            for key, trans in transform.items():
                if key in inputs:
                    inputs[key] = trans.transform(inputs[key])
        # ``**forward_kwargs`` always binds a dict, so it can be splatted as is.
        y_hat, aux_loss = self.forward(**inputs, **forward_kwargs)
        # Rescale outputs
        if postprocess:
            trans = transform.get('y')
            if trans is not None:
                y_hat = trans.inverse_transform(y_hat)
        if return_target:
            # NOTE(review): this branch returns four values. Helpers inherited
            # from the base Predictor that call
            # ``predict_batch(..., return_target=True)`` may expect three;
            # confirm before relying on them.
            y = targets.get('y')
            return y, y_hat, mask, aux_loss
        return y_hat, aux_loss

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return a dict with the targets, the rescaled predictions and the mask.

        The auxiliary loss returned by the model is discarded.
        """
        # Only the target mapping and the mask are needed here. Inputs and
        # transforms are handled inside ``predict_batch``.
        _, targets, mask, _ = self._unpack_batch(batch)
        y_hat, _ = self.predict_batch(batch, preprocess=False, postprocess=True)
        output = dict(**targets, y_hat=y_hat)
        if mask is not None:
            output['mask'] = mask
        return output

    def _shared_loss_step(self, batch):
        """Compute the loss and metric inputs shared by training and validation.

        Returns:
            ``(loss, y_hat, y, mask)``. ``loss`` is the task loss plus the
            auxiliary loss. ``y_hat`` is detached and rescaled for metrics.
            ``y`` is ``batch.y`` as stored in the batch.
        """
        y = y_loss = batch.y
        mask = batch.get('mask')
        # When scale_target is set, the loss is computed in the scaled space,
        # so the predictions are not inverse-transformed by predict_batch.
        y_hat_loss, aux_loss = self.predict_batch(
            batch, preprocess=False, postprocess=not self.scale_target)
        y_hat = y_hat_loss.detach()
        if self.scale_target:
            # Scale the target for the loss; rescale the detached predictions
            # for the metrics.
            y_loss = batch.transform['y'].transform(y)
            y_hat = batch.transform['y'].inverse_transform(y_hat)
        loss = self.loss_fn(y_hat_loss, y_loss, mask) + aux_loss
        return loss, y_hat, y, mask

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss (task loss + auxiliary loss)."""
        loss, y_hat, y, mask = self._shared_loss_step(batch)
        # Logging
        self.train_metrics.update(y_hat, y, mask)
        self.log_metrics(self.train_metrics, batch_size=batch.batch_size)
        self.log_loss('train', loss, batch_size=batch.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss (task loss + auxiliary loss)."""
        val_loss, y_hat, y, mask = self._shared_loss_step(batch)
        # Logging
        self.val_metrics.update(y_hat, y, mask)
        self.log_metrics(self.val_metrics, batch_size=batch.batch_size)
        self.log_loss('val', val_loss, batch_size=batch.batch_size)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Compute, log and return the test loss (task loss + auxiliary loss).

        Unlike training and validation, the loss is computed on the
        postprocessed predictions against ``batch.y``.
        """
        # Compute outputs and rescale
        y_hat, aux_loss = self.predict_batch(batch, preprocess=False,
                                             postprocess=True)
        y, mask = batch.y, batch.get('mask')
        test_loss = self.loss_fn(y_hat, y, mask) + aux_loss
        # Logging
        self.test_metrics.update(y_hat.detach(), y, mask)
        self.log_metrics(self.test_metrics, batch_size=batch.batch_size)
        self.log_loss('test', test_loss, batch_size=batch.batch_size)
        return test_loss

    def compute_metrics(self, batch, preprocess=False, postprocess=True):
        """Compute the test metrics on a single batch.

        ``test_metrics`` is reset afterwards, so no state carries over to
        later calls. The auxiliary loss is ignored.

        Returns:
            ``(metrics_dict, y_hat)``.
        """
        # Compute outputs and rescale
        y_hat, _ = self.predict_batch(batch, preprocess, postprocess)
        y, mask = batch.y, batch.get('mask')
        self.test_metrics.update(y_hat.detach(), y, mask)
        metrics_dict = self.test_metrics.compute()
        self.test_metrics.reset()
        return metrics_dict, y_hat