Untitled

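import functools


def ewc(self, loss_fn, lam=10):
    """Wrap loss_fn so it also adds the EWC penalty, weighted by lam."""
    @functools.wraps(loss_fn)
    def wrapper(*args, **kwargs):
        loss = loss_fn(*args, **kwargs)
        # No Fisher matrix yet (e.g. first task): return the plain task loss.
        if self.fisher_matrix is not None:
            # Out-of-place add so the tensor returned by loss_fn is not mutated.
            loss = loss + self.ewc_loss(lam)
        return loss
    return wrapper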
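The snippet relies on `self.fisher_matrix` and `self.ewc_loss(lam)`, which are not shown. Below is a minimal, dependency-free sketch of what they might look like, assuming the standard diagonal-Fisher EWC penalty from Kirkpatrick et al. (2017), `lam / 2 * sum_i F_i * (theta_i - theta*_i) ** 2`. The class name `EWCRegularizer`, the `params` and `anchor_params` attributes, and the `consolidate` method are hypothetical; parameters are plain floats instead of tensors only to keep the example self-contained. Some implementations drop the `1/2` factor, which simply rescales `lam`.

```python
import functools


class EWCRegularizer:
    """Hypothetical container: model parameters stored as a flat list of floats."""

    def __init__(self, params):
        self.params = list(params)
        self.anchor_params = None  # hypothetical: theta* saved after the old task
        self.fisher_matrix = None  # diagonal Fisher estimate, one entry per parameter

    def consolidate(self, per_sample_grads):
        # Empirical diagonal Fisher: mean of squared per-sample gradients of the
        # old task's log-likelihood, taken at the current (trained) weights.
        n = len(per_sample_grads)
        self.fisher_matrix = [
            sum(g[i] ** 2 for g in per_sample_grads) / n
            for i in range(len(self.params))
        ]
        # Freeze the current weights as the anchor theta*.
        self.anchor_params = list(self.params)

    def ewc_loss(self, lam):
        # lam / 2 * sum_i F_i * (theta_i - theta*_i) ** 2
        return lam / 2 * sum(
            f * (p - a) ** 2
            for f, p, a in zip(self.fisher_matrix, self.params, self.anchor_params)
        )

    def ewc(self, loss_fn, lam=10):
        # Same wrapper as the snippet above, repeated so this block runs alone.
        @functools.wraps(loss_fn)
        def wrapper(*args, **kwargs):
            loss = loss_fn(*args, **kwargs)
            if self.fisher_matrix is not None:
                loss = loss + self.ewc_loss(lam)
            return loss
        return wrapper


if __name__ == "__main__":
    reg = EWCRegularizer([1.0, 2.0])

    def task_loss(x):
        return x * 0.5

    loss_fn = reg.ewc(task_loss, lam=10)
    print(loss_fn(4.0))  # 2.0 (no Fisher matrix yet, so no penalty)

    reg.consolidate([[1.0, 2.0], [0.0, 0.0]])  # Fisher diagonal -> [0.5, 2.0]
    reg.params = [2.0, 1.0]  # weights drift while training on a new task
    print(loss_fn(4.0))  # 14.5 = 2.0 task loss + 5.0 * (0.5 * 1 + 2.0 * 1)
```

Because the wrapper checks `fisher_matrix is not None` on every call, the same wrapped `loss_fn` can be created once and reused: it returns the plain task loss until `consolidate` runs, then starts penalizing movement away from the anchored weights in proportion to each parameter's estimated importance.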