my_module
unknown
python
3 years ago
811 B
7
Indexable
class My_d(nn.Module):
    r"""Multiply the input element-wise by a random 0/1 Bernoulli mask.

    Each element of the mask is 1 with probability ``meta_prob`` and 0
    otherwise, so ``meta_prob`` is the probability that an element is *kept*
    (the drop probability is ``1 - meta_prob``).

    Unlike ``nn.Dropout``, the mask is sampled and applied on every call,
    independent of ``self.training``, and surviving elements are not rescaled.

    Attributes:
        meta_prob: scalar ``nn.Parameter`` initialised to 0.5, used as the
            Bernoulli probability for the mask.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Positional and keyword arguments are accepted but not used.
        super().__init__()
        self.meta_prob = nn.Parameter(th.tensor(0.5))

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Return ``input`` with each element independently kept or zeroed.

        Args:
            input: tensor of any shape; the mask is created on its device.

        Returns:
            ``input * mask`` with the same shape as ``input``. Floating-point
            inputs keep their dtype; other dtypes follow the usual type
            promotion with the default-dtype mask.
        """
        # Sample in the default dtype, then fill in place with Bernoulli draws
        # parameterised by the (scalar) keep probability.
        mask = th.empty(input.shape, device=input.device)
        mask.bernoulli_(self.meta_prob)
        # Match the input's floating dtype so half/bfloat16 activations are
        # not silently promoted to float32 by the multiplication below.
        if input.is_floating_point():
            mask = mask.to(input.dtype)
        return input * mask
Editor is loading...