def ewc(self, loss_fn, lam=10):
    """Wrap ``loss_fn`` so that ``self.ewc_loss(lam)`` is added to its result.

    The returned wrapper calls ``loss_fn(*args, **kwargs)``. If
    ``self.fisher_matrix`` is not ``None`` at call time, it adds the value of
    ``self.ewc_loss(lam)`` to the result. Otherwise the plain loss is returned
    unchanged. (Presumably this is an Elastic Weight Consolidation penalty;
    ``ewc_loss`` is defined elsewhere.)

    Args:
        loss_fn: Callable producing the base loss.
        lam: Penalty strength forwarded to ``self.ewc_loss``.

    Returns:
        A callable with the same call signature and metadata as ``loss_fn``.
    """
    import functools

    @functools.wraps(loss_fn)
    def wrapper(*args, **kwargs):
        loss = loss_fn(*args, **kwargs)
        # Checked on every call so the penalty activates as soon as a Fisher
        # matrix has been stored, even after the wrapper was created.
        if self.fisher_matrix is not None:
            # Out-of-place add: `+=` would mutate the object returned by
            # loss_fn when its type supports __iadd__ (e.g. tensors, lists).
            loss = loss + self.ewc_loss(lam)
        return loss

    return wrapper