Untitled
unknown
python
a year ago
944 B
11
Indexable
def _tensor_stats(t):
    """Return a one-line "mean=…, std=…, min=…, max=…" summary of tensor *t*.

    ``t.std()`` uses torch's default unbiased estimator, so a single-element
    tensor reports ``nan`` for std.
    """
    return (
        f"mean={t.mean():.6f}, std={t.std():.6f}, "
        f"min={t.min():.6f}, max={t.max():.6f}"
    )


def _allclose_same_shape(a, b, rtol=1e-5, atol=1e-8):
    """Like ``torch.allclose`` but return False when shapes differ.

    Plain ``torch.allclose`` raises on non-broadcastable shapes and silently
    broadcasts a length-1 tensor against a longer one; neither is a
    meaningful "are these the same values" answer for this debug check.
    """
    return a.shape == b.shape and torch.allclose(a, b, rtol=rtol, atol=atol)


# Debug logging: summary statistics and closeness checks for a captured gradient.
# NOTE(review): `idx` is only used in the "Iteration" header — both iterations
# read self.hook_g[0] and therefore log identical values. grad_at_1 is
# hook_g[0] squared, not hook_g[1]; confirm whether hook_g[idx] / hook_g[1]
# was intended before relying on this comparison.
for idx in range(2):  # Assuming you have at least 2 iterations
    grad_at_0 = self.hook_g[0]
    grad_at_1 = self.hook_g[0].pow(2)
    logging.warning(f"Iteration {idx}")
    logging.warning(f"grad_at_0 stats: {_tensor_stats(grad_at_0)}")
    logging.warning(f"grad_at_1 stats: {_tensor_stats(grad_at_1)}")
    logging.warning(
        f"Are grad_at_0 and grad_at_1 close? "
        f"{_allclose_same_shape(grad_at_0, grad_at_1)}"
    )
    # If you still want to compare unique values.
    # torch.unique returns 1-D tensors whose lengths may differ between the two
    # inputs, so the shape-guarded comparison reports False instead of raising.
    unique_0 = torch.unique(grad_at_0)
    unique_1 = torch.unique(grad_at_1)
    logging.warning(
        f"Number of unique values: grad_at_0={len(unique_0)}, grad_at_1={len(unique_1)}"
    )
    logging.warning(
        f"Are unique values close? {_allclose_same_shape(unique_0, unique_1)}"
    )
Leave a Comment