Untitled
unknown
python
5 months ago
1.6 kB
3
Indexable
import torch
from torch.autograd import Function
from torch.autograd import gradcheck
from typing import Tuple, Optional, Union


class Summation(Function):
    """Custom autograd Function computing ``torch.sum(input, dim)``.

    Uses the new-style (torch >= 1.13) Function API where ``forward`` does not
    receive ``ctx``; saving state for ``backward`` happens in ``setup_context``.
    """

    @staticmethod
    def forward(input: torch.Tensor, dim: Optional[Tuple[int, ...]]):
        """Sum ``input`` over the axes in ``dim`` (all axes when dim is None)."""
        # Student code start
        if dim is None:
            # Full reduction to a scalar. Branch explicitly: passing dim=None
            # positionally to torch.sum is not accepted by all torch versions.
            return torch.sum(input)
        return torch.sum(input, dim)
        # Student code end

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Student code start
        input, dim = inputs
        # backward only needs the input's shape, not its values; storing the
        # shape directly avoids keeping the whole input tensor alive.
        ctx.input_shape = input.shape
        ctx.dim = dim
        # Student code end

    @staticmethod
    def backward(ctx, grad_output):
        # Student code start
        shape = ctx.input_shape
        dim = ctx.dim

        # d(sum)/d(input) = 1 for every summed element, so the input gradient
        # is just grad_output broadcast back to the input's shape.
        if dim is None:
            # Scalar output: every input element's gradient is grad_output.
            grad_input = grad_output.expand(shape)
        else:
            # Re-insert size-1 axes where the reduction removed them so that
            # expand() broadcasts along exactly the summed dimensions.
            # e.g. input (3, 4, 5), dim=(0, 1):
            #   grad (5,) -> unsqueeze -> (1, 1, 5) -> expand -> (3, 4, 5)
            # Normalize negative axes (d % ndim) before sorting so e.g.
            # dim=(-1,) unsqueezes at the correct position.
            ndim = len(shape)
            for d in sorted(d % ndim for d in dim):
                grad_output = grad_output.unsqueeze(d)
            grad_input = grad_output.expand(shape)

        # forward() has exactly two inputs (input, dim); autograd requires one
        # gradient per input, and dim is non-differentiable -> None.
        return grad_input, None
        # Student code end
Editor is loading...
Leave a Comment