nord vpnnord vpn
Ad

Untitled

mail@pastecode.io avatar
unknown
python
7 months ago
4.3 kB
3
Indexable
Never
from tsl.engines import Predictor
from tsl.data import Data

# Custom predictor that takes into account the auxiliary loss from MinCutPool
class CustomPredictor(Predictor):
    """Predictor for models whose forward pass returns ``(y_hat, aux_loss)``.

    Every step unpacks the model output as a pair and adds the auxiliary
    term (e.g. the MinCutPool regularisation loss) to the value produced by
    ``self.loss_fn``. Metrics are always computed on the prediction alone.
    """

    def predict_batch(self, batch: Data,
                      preprocess: bool = False,
                      postprocess: bool = True,
                      return_target: bool = False,
                      **forward_kwargs):
        """Run the model on ``batch`` and return prediction and auxiliary loss.

        Args:
            batch: Batch handed to ``self._unpack_batch``.
            preprocess: If ``True``, apply ``transform[key].transform`` to
                every input that has a transform registered under its key.
            postprocess: If ``True`` and a transform is registered under
                ``'y'``, apply its ``inverse_transform`` to the prediction.
            return_target: If ``True``, also return the ``'y'`` target and
                the mask.
            **forward_kwargs: Extra keyword arguments forwarded to
                ``self.forward``.

        Returns:
            ``(y_hat, aux_loss)``, or ``(y, y_hat, mask, aux_loss)`` when
            ``return_target`` is ``True``.
        """
        inputs, targets, mask, transform = self._unpack_batch(batch)
        if preprocess:
            for key, trans in transform.items():
                if key in inputs:
                    inputs[key] = trans.transform(inputs[key])

        # ``**forward_kwargs`` always binds a dict (possibly empty), so it can
        # be forwarded directly without a ``None`` check.
        y_hat, aux_loss = self.forward(**inputs, **forward_kwargs)

        # Rescale outputs
        if postprocess:
            trans = transform.get('y')
            if trans is not None:
                y_hat = trans.inverse_transform(y_hat)
        if return_target:
            y = targets.get('y')
            return y, y_hat, mask, aux_loss
        return y_hat, aux_loss

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the targets, the rescaled prediction and the mask, if any.

        The auxiliary loss is discarded here: only the prediction is part of
        the output dictionary.
        """
        # Only the targets and the mask are needed from the unpacked batch.
        _, y, mask, _ = self._unpack_batch(batch)

        # Make predictions
        y_hat, _ = self.predict_batch(batch, preprocess=False,
                                      postprocess=True)

        output = dict(**y, y_hat=y_hat)
        if mask is not None:
            output['mask'] = mask
        return output

    def _loss_and_metric_inputs(self, batch):
        """Compute the total loss and the tensors used to update metrics.

        Shared by ``training_step`` and ``validation_step``.

        Returns:
            ``(loss, y_hat, y, mask)`` where ``loss`` is
            ``self.loss_fn(...) + aux_loss`` and ``y_hat`` is the detached
            prediction to be compared against the unscaled target ``y``.
        """
        y = y_loss = batch.y
        mask = batch.get('mask')

        # When ``scale_target`` is set the prediction is left un-rescaled
        # (postprocess=False) so the loss is computed on transformed values.
        y_hat_loss, aux_loss = self.predict_batch(
            batch, preprocess=False, postprocess=not self.scale_target)
        y_hat = y_hat_loss.detach()

        # Scale target and output, eventually
        if self.scale_target:
            # Loss: transformed target vs. raw model output.
            y_loss = batch.transform['y'].transform(y)
            # Metrics: inverse-transformed prediction vs. original target.
            y_hat = batch.transform['y'].inverse_transform(y_hat)

        loss = self.loss_fn(y_hat_loss, y_loss, mask) + aux_loss
        return loss, y_hat, y, mask

    def training_step(self, batch, batch_idx):
        """Compute the training loss (including ``aux_loss``) and log it."""
        loss, y_hat, y, mask = self._loss_and_metric_inputs(batch)

        # Logging
        self.train_metrics.update(y_hat, y, mask)
        self.log_metrics(self.train_metrics, batch_size=batch.batch_size)
        self.log_loss('train', loss, batch_size=batch.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss (including ``aux_loss``) and log it."""
        val_loss, y_hat, y, mask = self._loss_and_metric_inputs(batch)

        # Logging
        self.val_metrics.update(y_hat, y, mask)
        self.log_metrics(self.val_metrics, batch_size=batch.batch_size)
        self.log_loss('val', val_loss, batch_size=batch.batch_size)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Compute the test loss on rescaled predictions and log it.

        The returned loss is ``self.loss_fn(y_hat, y, mask) + aux_loss``.
        """
        # Compute outputs and rescale
        y_hat, aux_loss = self.predict_batch(batch, preprocess=False,
                                             postprocess=True)

        y, mask = batch.y, batch.get('mask')
        test_loss = self.loss_fn(y_hat, y, mask) + aux_loss

        # Logging
        self.test_metrics.update(y_hat.detach(), y, mask)
        self.log_metrics(self.test_metrics, batch_size=batch.batch_size)
        self.log_loss('test', test_loss, batch_size=batch.batch_size)
        return test_loss

    def compute_metrics(self, batch, preprocess=False, postprocess=True):
        """Evaluate ``self.test_metrics`` on a single batch.

        The metrics are updated with this batch, computed and then reset.
        The auxiliary loss returned by ``predict_batch`` is ignored.

        Returns:
            ``(metrics_dict, y_hat)``: the computed metrics and the
            prediction they were computed from.
        """
        # Compute outputs and rescale
        y_hat, _ = self.predict_batch(batch, preprocess, postprocess)
        y, mask = batch.y, batch.get('mask')
        self.test_metrics.update(y_hat.detach(), y, mask)
        metrics_dict = self.test_metrics.compute()
        self.test_metrics.reset()
        return metrics_dict, y_hat

nord vpnnord vpn
Ad