my_module

mail@pastecode.io avatar
unknown
python
2 years ago
811 B
5
Indexable
Never
class My_d(nn.Module):
    r"""Multiply the input elementwise by a random Bernoulli(``meta_prob``) mask.

    Each element is kept (multiplied by 1) with probability ``meta_prob`` and
    zeroed with probability ``1 - meta_prob``. With the initial value 0.5 the
    keep and drop probabilities coincide. Unlike :class:`torch.nn.Dropout`:

    * surviving elements are NOT rescaled by ``1 / keep_prob``;
    * the mask is sampled on every call, regardless of ``self.training``.

    ``meta_prob`` is an ``nn.Parameter`` initialised to 0.5 and must stay in
    [0, 1] for ``bernoulli_`` to accept it.

    NOTE(review): Bernoulli sampling is not differentiable, so backprop does
    not deliver a useful gradient to ``meta_prob``. If it is meant to be
    learned, confirm how it is updated.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Positional/keyword arguments are accepted for call-site compatibility
        # but are not used.
        super().__init__()
        self.meta_prob = nn.Parameter(th.tensor(0.5))

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Return ``input`` multiplied by a fresh 0/1 mask of the same shape.

        For floating-point inputs the mask is created in ``input.dtype``, so
        the output keeps the input's precision (e.g. bfloat16 stays bfloat16).
        Other dtypes use the default floating dtype for the mask, as before.
        """
        # Match the input's floating dtype so low-precision activations are not
        # silently promoted by the multiply; non-float inputs keep the previous
        # behaviour (default float dtype mask).
        mask_dtype = input.dtype if input.is_floating_point() else None
        mask = th.empty(input.shape, dtype=mask_dtype, device=input.device)
        # Each mask entry is 1 with probability meta_prob, else 0.
        mask.bernoulli_(self.meta_prob)
        return input * mask