Untitled

 avatar
unknown
python
5 months ago
1.6 kB
3
Indexable
import torch
from torch.autograd import Function
from torch.autograd import gradcheck

from typing import Tuple, Optional, Union

class Summation(Function):
    """Autograd function computing ``torch.sum(input, dim)`` with a hand-written backward.

    ``dim`` may be a single int, a tuple/list of ints (negative values allowed),
    or ``None`` to reduce over every dimension. Reduced dimensions are dropped
    from the output (there is no ``keepdim`` option).
    """

    @staticmethod
    def forward(input: torch.Tensor, dim: Optional[Union[int, Tuple[int, ...]]]):
        # Student code start
        if dim is None:
            # Reduce over all dimensions -> 0-d tensor.
            return torch.sum(input)
        return torch.sum(input, dim)
        # Student code end

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Student code start
        input, dim = inputs
        # Backward only needs the input's shape, so store that instead of the
        # tensor itself (avoids keeping the input's storage alive).
        ctx.input_shape = input.shape
        ctx.dim = dim
        # Student code end

    @staticmethod
    def _reduced_dims(dim, ndim):
        """Return the reduced axes as sorted, non-negative indices into an ``ndim``-d input."""
        if ndim == 0:
            # Summing a 0-d tensor (over dim 0/-1 or None) yields a 0-d tensor:
            # there is no axis to restore in backward.
            return ()
        if dim is None:
            return tuple(range(ndim))
        if isinstance(dim, int):
            dim = (dim,)
        # Negative dims count from the end of the *input*. Make them absolute so
        # that unsqueezing in ascending order re-inserts every axis at its
        # original position.
        return tuple(sorted(d % ndim for d in dim))

    @staticmethod
    def backward(ctx, grad_output):
        # Student code start
        grad_input = None
        if ctx.needs_input_grad[0]:
            shape = ctx.input_shape
            # Every input element contributes with coefficient 1 to the output
            # element it was summed into, so the input gradient is grad_output
            # broadcast back over the reduced axes. Example: input (3, 4, 5),
            # dim=(0, 1) -> output (5,) -> unsqueeze to (1, 1, 5) -> expand to
            # (3, 4, 5).
            for d in Summation._reduced_dims(ctx.dim, len(shape)):
                grad_output = grad_output.unsqueeze(d)
            grad_input = grad_output.expand(shape)
        # One gradient per forward input; ``dim`` is not differentiable.
        return grad_input, None
        # Student code end
Editor is loading...
Leave a Comment