my_module
unknown
python
3 years ago
811 B
8
Indexable
class My_d(nn.Module):
    r"""Multiply the input by a random Bernoulli mask with rate ``meta_prob``.

    Each element of the input is independently kept with probability
    ``meta_prob`` and zeroed otherwise, i.e. the *drop* probability is
    ``1 - meta_prob``. A fresh mask is drawn on every call. Kept elements are
    not rescaled, and the mask is applied regardless of ``self.training``.

    NOTE(review): Bernoulli sampling is not differentiable with respect to its
    probability, so ``meta_prob`` presumably receives no useful gradient from
    this layer — confirm whether it is meant to be learned.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Extra positional/keyword arguments are accepted for call-site
        # compatibility but are ignored.
        super().__init__()
        # Keep-probability, stored as a 0-dim parameter initialised to 0.5.
        self.meta_prob = nn.Parameter(th.tensor(0.5))

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Return ``input`` with each element kept with probability ``meta_prob``.

        Args:
            input: tensor of any shape.

        Returns:
            Tensor with the same shape, device and dtype as ``input``; each
            element is either the original value or zero.
        """
        # Sample the 0/1 mask in the default floating dtype (as before).
        mask = th.empty(input.shape, device=input.device).bernoulli_(self.meta_prob)
        # Cast before multiplying so that low-precision inputs (float16,
        # bfloat16) are not silently promoted to float32 by type promotion.
        return input * mask.to(dtype=input.dtype)