Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
226 B
3
Indexable
Never
def ewc(self, loss_fn, lam=10):
    """Wrap ``loss_fn`` so its result has ``self.ewc_loss(lam)`` added to it.

    The name and the ``fisher_matrix`` / ``ewc_loss`` attributes suggest an
    Elastic Weight Consolidation penalty. That is an assumption: neither
    attribute is defined in this snippet.

    Args:
        self: Object providing ``fisher_matrix`` and ``ewc_loss(lam)``.
            ``fisher_matrix`` is compared against ``None`` on every call.
            This function is written like a method but sits at module level
            here; it was presumably pasted from a class body (confirm).
        loss_fn: Callable that returns the base loss. The wrapper forwards
            all positional and keyword arguments to it unchanged.
        lam: Value passed to ``self.ewc_loss``. Presumably the penalty
            weight (lambda); verify against ``ewc_loss``.

    Returns:
        A wrapper callable that carries ``loss_fn``'s metadata (``__name__``,
        ``__doc__``, ``__wrapped__``). When ``self.fisher_matrix`` is
        ``None``, the wrapper returns ``loss_fn(...)`` unchanged. Otherwise
        it returns ``loss_fn(...) + self.ewc_loss(lam)``.
    """
    import functools  # local import keeps this snippet self-contained

    @functools.wraps(loss_fn)
    def wrapper(*args, **kwargs):
        loss = loss_fn(*args, **kwargs)
        # The check runs at call time, not at wrap time, so the penalty
        # starts applying as soon as fisher_matrix is set to a non-None value.
        if self.fisher_matrix is None:
            return loss
        # Out-of-place add on purpose. ``+=`` would mutate the object
        # loss_fn returned (e.g. a tensor or array the caller also holds).
        # In-place ops can also be rejected by autograd or fail NumPy's
        # casting rules.
        return loss + self.ewc_loss(lam)

    return wrapper