Untitled

 avatar
unknown
plain_text
2 years ago
2.3 kB
6
Indexable
def forward(self, X):
        """Max-pool ``X`` over square windows and cache data for backward.

        Args:
            X: 4-D NumPy array unpacked as ``(N, C, H, W)`` (presumably
                batch, channels, height, width -- confirm with callers).
                Pooling runs over the last two axes.

        Returns:
            Array of shape ``(N, C, H_out, W_out)`` holding the maximum of
            each ``kernel_size x kernel_size`` window, where
            ``H_out = (H - kernel_size) // stride + 1`` (same for ``W_out``).

        Side effects (only when ``self.trainable`` is truthy):
            * ``self.cache['X_shape']`` / ``self.cache['X_strides']`` are set.
            * If ``self.same_kernel_stride`` is truthy, ``self.cache['mask']``
              is a boolean array shaped like ``X`` that is True wherever an
              element equals the maximum of its window (ties are all
              marked).  Rows/columns not covered by any window are False.
            * Otherwise ``self.cache['strided_X']`` keeps the windowed view.
        """
        N, C, H, W = X.shape
        if self.trainable:
            self.cache['X_shape'] = X.shape
            self.cache['X_strides'] = X.strides
        kernel_size, stride = self.params["kernel_size"], self.params["stride"]

        # Output spatial size: windows that would run past the edge are dropped.
        H_out = (H - kernel_size) // stride + 1
        W_out = (W - kernel_size) // stride + 1

        # Build a 6-D view of X whose last two axes index inside one window
        # and whose axes 2/3 index the window position.  No data is copied.
        N_strides, C_strides, H_strides, W_strides = X.strides
        strided_X = np.lib.stride_tricks.as_strided(
            X,
            shape=(N, C, H_out, W_out, kernel_size, kernel_size),
            strides=(
                N_strides,
                C_strides,
                stride * H_strides,
                stride * W_strides,
                H_strides,
                W_strides,
            ),
        )

        # Max over the two in-window axes.
        output = np.max(strided_X, axis=(4, 5))

        if self.trainable:
            if self.same_kernel_stride:
                # Windows do not overlap, so repeating each maximum `stride`
                # times along both spatial axes lines it up with the input
                # elements of its own window.
                window_maxes = output.repeat(stride, axis=-2).repeat(stride, axis=-1)
                covered_h, covered_w = window_maxes.shape[-2:]

                # Compare only inside the region the windows actually cover.
                # Zero-padding the maxima up to (H, W) and comparing against
                # all of X would flag any 0 in the leftover border as a
                # "max", although no window contains it.
                mask = np.zeros(X.shape, dtype=bool)
                mask[..., :covered_h, :covered_w] = np.equal(
                    X[..., :covered_h, :covered_w], window_maxes
                )
                self.cache['mask'] = mask
            else:
                self.cache['strided_X'] = strided_X
        return output

    def fully_vectorized_backward(self, dL_dy):
        """Send the upstream gradient back to the cached max positions.

        Approach taken from:
        https://stackoverflow.com/questions/61954727/max-pooling-backpropagation-using-numpy

        NOTE (original author): vectorizing this gave little speed-up.

        Reads ``self.params["stride"]``, ``self.cache['X_shape']`` and
        ``self.cache['mask']``.  ``forward`` only stores the mask when the
        layer is trainable and ``same_kernel_stride`` is set, so a missing
        mask surfaces here as a ``KeyError``.

        Args:
            dL_dy: 4-D array, presumably the loss gradient w.r.t. the pooled
                output -- its last two axes are expanded by ``stride``.

        Returns:
            Array with the cached input shape: each entry of ``dL_dy`` is
            copied over its ``stride x stride`` patch, zero-filled at the
            bottom/right to reach ``(H, W)``, then multiplied by the mask.
        """
        stride = self.params["stride"]
        _, _, in_h, in_w = self.cache['X_shape']

        # Tile each gradient value across the patch its window came from.
        tiled = dL_dy.repeat(stride, axis=-2)
        tiled = tiled.repeat(stride, axis=-1)

        # Input sizes that are not a multiple of the stride leave uncovered
        # rows/columns at the end; those receive zero gradient.
        extra_rows = in_h - tiled.shape[-2]
        extra_cols = in_w - tiled.shape[-1]
        pad_width = ((0, 0), (0, 0), (0, extra_rows), (0, extra_cols))
        tiled = np.pad(tiled, pad_width)

        # Keep the gradient only where the input matched its window maximum.
        return np.multiply(tiled, self.cache['mask'])
Editor is loading...