from tsl.engines import Predictor
from tsl.data import Data
# Custom predictor that takes into account the auxiliary loss from MinCutPool
class CustomPredictor(Predictor):
def predict_batch(self, batch: Data,
preprocess: bool = False,
postprocess: bool = True,
return_target: bool = False,
**forward_kwargs):
""""""
inputs, targets, mask, transform = self._unpack_batch(batch)
if preprocess:
for key, trans in transform.items():
if key in inputs:
inputs[key] = trans.transform(inputs[key])
if forward_kwargs is None:
forward_kwargs = dict()
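        # Forward pass: the wrapped model is expected to return a (predictions, aux_loss) tuple.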
y_hat, aux_loss = self.forward(**inputs, **forward_kwargs)
# Rescale outputs
if postprocess:
trans = transform.get('y')
if trans is not None:
y_hat = trans.inverse_transform(y_hat)
if return_target:
y = targets.get('y')
return y, y_hat, mask, aux_loss
return y_hat, aux_loss
def predict_step(self, batch, batch_idx, dataloader_idx=None):
""""""
# Unpack batch
x, y, mask, transform = self._unpack_batch(batch)
# Make predictions
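        # The auxiliary loss is only needed for optimization, so it is dropped here.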
y_hat, _ = self.predict_batch(batch, preprocess=False, postprocess=True)
output = dict(**y, y_hat=y_hat)
if mask is not None:
output['mask'] = mask
return output
def training_step(self, batch, batch_idx):
""""""
y = y_loss = batch.y
mask = batch.get('mask')
# Compute predictions and compute loss
y_hat_loss, aux_loss = self.predict_batch(batch, preprocess=False,
postprocess=not self.scale_target)
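        # Detach a copy of the predictions for metric logging, so it stays out of the backward pass.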
y_hat = y_hat_loss.detach()
        # Rescale target and output, if needed
if self.scale_target:
y_loss = batch.transform['y'].transform(y)
y_hat = batch.transform['y'].inverse_transform(y_hat)
# Compute loss
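        # Add the MinCut auxiliary regularization (cut + orthogonality terms) to the prediction loss.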
loss = self.loss_fn(y_hat_loss, y_loss, mask) + aux_loss
# Logging
self.train_metrics.update(y_hat, y, mask)
self.log_metrics(self.train_metrics, batch_size=batch.batch_size)
self.log_loss('train', loss, batch_size=batch.batch_size)
return loss
def validation_step(self, batch, batch_idx):
""""""
y = y_loss = batch.y
mask = batch.get('mask')
# Compute predictions
y_hat_loss, aux_loss = self.predict_batch(batch, preprocess=False,
postprocess=not self.scale_target)
y_hat = y_hat_loss.detach()
        # Rescale target and output, if needed
if self.scale_target:
y_loss = batch.transform['y'].transform(y)
y_hat = batch.transform['y'].inverse_transform(y_hat)
# Compute loss
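        # The auxiliary term is included in the validation loss as well, so model
        # selection on val_loss tracks the full training objective.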
val_loss = self.loss_fn(y_hat_loss, y_loss, mask) + aux_loss
# Logging
self.val_metrics.update(y_hat, y, mask)
self.log_metrics(self.val_metrics, batch_size=batch.batch_size)
self.log_loss('val', val_loss, batch_size=batch.batch_size)
return val_loss
def test_step(self, batch, batch_idx):
""""""
# Compute outputs and rescale
y_hat, aux_loss = self.predict_batch(batch, preprocess=False, postprocess=True)
y, mask = batch.y, batch.get('mask')
test_loss = self.loss_fn(y_hat, y, mask) + aux_loss
# Logging
self.test_metrics.update(y_hat.detach(), y, mask)
self.log_metrics(self.test_metrics, batch_size=batch.batch_size)
self.log_loss('test', test_loss, batch_size=batch.batch_size)
return test_loss
def compute_metrics(self, batch, preprocess=False, postprocess=True):
""""""
# Compute outputs and rescale
y_hat, _ = self.predict_batch(batch, preprocess, postprocess)
y, mask = batch.y, batch.get('mask')
self.test_metrics.update(y_hat.detach(), y, mask)
metrics_dict = self.test_metrics.compute()
self.test_metrics.reset()
        return metrics_dict, y_hat
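

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original snippet). It shows the contract
# CustomPredictor relies on: the wrapped model's forward() must return a
# (predictions, aux_loss) tuple. ToyAuxLossModel and its hyperparameters are
# hypothetical placeholders for an actual MinCutPool-based model, and the
# Predictor constructor arguments (model_class, model_kwargs, optim_class,
# optim_kwargs, loss_fn, metrics, scale_target) follow the interface shown in
# the tsl examples; check them, and the metrics import path, against the
# installed tsl version.

import torch
from torch import nn

from tsl.metrics.torch import MaskedMAE  # recent tsl releases; older ones expose it elsewhere


class ToyAuxLossModel(nn.Module):
    """Hypothetical stand-in for a MinCutPool-based spatiotemporal model."""

    def __init__(self, input_size: int, horizon: int, hidden_size: int = 32):
        super().__init__()
        self.horizon = horizon
        self.input_size = input_size
        self.encoder = nn.Linear(input_size, hidden_size)
        self.readout = nn.Linear(hidden_size, horizon * input_size)

    def forward(self, x, **kwargs):
        # x has shape [batch, time, nodes, features]; encode the last time step only.
        h = torch.relu(self.encoder(x[:, -1]))
        out = self.readout(h)  # [batch, nodes, horizon * features]
        y_hat = out.reshape(x.size(0), x.size(2), self.horizon,
                            self.input_size).transpose(1, 2)
        # Placeholder for the MinCut auxiliary (cut + orthogonality) loss.
        aux_loss = torch.zeros((), device=x.device)
        return y_hat, aux_loss


predictor = CustomPredictor(
    model_class=ToyAuxLossModel,
    model_kwargs=dict(input_size=1, horizon=12, hidden_size=32),
    optim_class=torch.optim.Adam,
    optim_kwargs=dict(lr=1e-3),
    loss_fn=MaskedMAE(),
    metrics=dict(mae=MaskedMAE()),
    scale_target=True,
)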