Untitled

mail@pastecode.io avatar
unknown
python
18 days ago
944 B
2
Indexable
Never
# Debug helper: log summary statistics for the captured gradient
# `self.hook_g[0]` and for its element-wise square, then report whether the
# two tensors (and their sets of unique values) are numerically close.
#
# NOTE(review): both tensors are derived from `self.hook_g[0]`, and `idx` is
# only used in the log label, so every iteration logs the same numbers.
# Possibly `self.hook_g[idx]` (or `self.hook_g[1]`) was intended -- confirm
# against the code that fills `hook_g`.
for idx in range(2):  # Assuming you have at least 2 iterations
    grad_at_0 = self.hook_g[0]
    grad_at_1 = self.hook_g[0].pow(2)  # element-wise square of the same entry

    logging.warning(f"Iteration {idx}")
    logging.warning(f"grad_at_0 stats: mean={grad_at_0.mean():.6f}, std={grad_at_0.std():.6f}, min={grad_at_0.min():.6f}, max={grad_at_0.max():.6f}")
    logging.warning(f"grad_at_1 stats: mean={grad_at_1.mean():.6f}, std={grad_at_1.std():.6f}, min={grad_at_1.min():.6f}, max={grad_at_1.max():.6f}")

    # Same shape by construction (pow keeps the shape), so allclose is safe.
    logging.warning(f"Are grad_at_0 and grad_at_1 close? {torch.allclose(grad_at_0, grad_at_1, rtol=1e-5, atol=1e-8)}")

    # If you still want to compare unique values
    unique_0 = torch.unique(grad_at_0)
    unique_1 = torch.unique(grad_at_1)
    logging.warning(f"Number of unique values: grad_at_0={len(unique_0)}, grad_at_1={len(unique_1)}")

    # torch.unique returns 1-D tensors whose lengths depend on the data (the
    # square can merge values such as -1 and 1). torch.allclose raises on
    # non-broadcastable shapes and silently broadcasts a length-1 operand, so
    # only compare element-wise when both have the same number of values.
    unique_close = unique_0.shape == unique_1.shape and torch.allclose(
        unique_0, unique_1, rtol=1e-5, atol=1e-8
    )
    logging.warning(f"Are unique values close? {unique_close}")
Leave a Comment