Untitled
unknown
python
3 years ago
226 B
11
Indexable
def ewc(self, loss_fn, lam=10):
    """Wrap ``loss_fn`` so the returned loss includes the ``ewc_loss`` penalty.

    The returned callable forwards all positional and keyword arguments to
    ``loss_fn``. If ``self.fisher_matrix`` is not ``None`` at call time, the
    result of ``self.ewc_loss(lam)`` is added to the base loss; otherwise the
    base loss is returned unchanged.

    NOTE(review): "ewc" presumably stands for Elastic Weight Consolidation;
    confirm against the definition of ``ewc_loss``.

    Args:
        loss_fn: Callable that computes the base loss.
        lam: Value passed through to ``self.ewc_loss``. Defaults to 10.

    Returns:
        A callable that accepts the same arguments as ``loss_fn`` and returns
        the (possibly penalised) loss.
    """

    def wrapper(*args, **kwargs):
        loss = loss_fn(*args, **kwargs)
        # Checked on every call rather than once at wrap time, so a Fisher
        # matrix assigned after wrapping still takes effect.
        if self.fisher_matrix is not None:
            # Out-of-place add: ``+=`` would mutate the object returned by
            # ``loss_fn`` in place for types that implement ``__iadd__``.
            loss = loss + self.ewc_loss(lam)
        return loss

    return wrapper