Untitled
unknown
python
a year ago
1.6 kB
5
Indexable
import torch
from torch.autograd import Function
from torch.autograd import gradcheck
from typing import Tuple, Optional, Union
class Summation(Function):
    """Custom autograd Function implementing torch.sum over given dims.

    Forward reduces ``input`` over the axes in ``dim``.  Backward exploits
    the fact that sum() has a local gradient of 1 for every reduced
    element, so the upstream gradient is simply broadcast back to the
    input's original shape.
    """

    @staticmethod
    def forward(input: torch.Tensor, dim: Tuple[int, ...]) -> torch.Tensor:
        # Student code start
        return torch.sum(input, dim)
        # Student code end

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        # Student code start
        input, dim = inputs
        # Backward only needs the input's shape, not its values — storing
        # the shape avoids keeping the whole tensor alive in the graph.
        ctx.input_shape = input.shape
        # Normalize dim: accept a bare int as well as a tuple, and fold
        # negative axes into the [0, ndim) range so that backward can
        # unsqueeze in ascending positional order.
        if isinstance(dim, int):
            dim = (dim,)
        ndim = input.dim()
        ctx.dim = tuple(sorted(d % ndim for d in dim))
        # Student code end

    @staticmethod
    def backward(ctx, grad_output):
        # Student code start
        # Example: input (3, 4, 5), dim=(0, 1) -> grad_output has shape
        # (5,).  Unsqueeze the reduced axes back in ascending order:
        # (5,) -> (1, 5) -> (1, 1, 5), then expand (broadcast) to (3, 4, 5).
        for d in ctx.dim:
            grad_output = grad_output.unsqueeze(d)
        grad_input = grad_output.expand(ctx.input_shape)
        # forward() took exactly two inputs (input, dim), so backward must
        # return exactly two gradients; dim is non-differentiable, hence
        # None.  (Returning three values here raises
        # "function returned an incorrect number of gradients (expected 2, got 3)".)
        return grad_input, None
        # Student code end
Leave a Comment