Untitled
unknown
python
3 years ago
5.6 kB
6
Indexable
def _kept_channels(norm, threshold, channel_lower_bound):
    """Return how many of ``norm``'s channels survive magnitude pruning.

    A channel is dropped when ``abs(norm.weight) < threshold``. The result is
    raised to ``channel_lower_bound`` but never exceeds ``norm.num_features``,
    so pruning can only shrink a layer. Non-affine norms (``weight is None``,
    e.g. ``InstanceNorm2d(affine=False)``) have no scale factors to rank, so
    every channel is kept.
    """
    num_features = norm.num_features
    if norm.weight is None:
        return num_features
    # int(...) so widths are plain Python ints, not 0-dim tensors.
    dropped = int((norm.weight.detach().abs() < threshold).sum())
    return min(num_features, max(num_features - dropped, channel_lower_bound))


def _rebuild_norm(norm, norm_layer, num_features, device):
    """Return a freshly initialised ``norm_layer`` of width ``num_features``.

    ``eps``, ``momentum``, ``affine`` and ``track_running_stats`` are copied
    from ``norm``; learned parameters and running statistics are not.
    """
    return norm_layer(
        num_features,
        eps=norm.eps,
        momentum=norm.momentum,
        affine=norm.affine,
        track_running_stats=norm.track_running_stats,
    ).to(device)


def _rebuild_conv(conv, conv_layer, in_channels, out_channels):
    """Return a freshly initialised ``conv_layer`` with new channel widths.

    Kernel size, stride, padding, dilation, padding mode and bias presence are
    copied from ``conv`` (plus ``output_padding`` for transposed convs) so the
    spatial geometry is unchanged. Weights are not copied. The new module is
    placed on ``conv``'s device.
    """
    kwargs = {
        "stride": conv.stride,
        "padding": conv.padding,
        "dilation": conv.dilation,
        "bias": conv.bias is not None,
        "padding_mode": conv.padding_mode,
    }
    if issubclass(conv_layer, nn.ConvTranspose2d):
        # Dropping output_padding would change the up-sampled spatial size
        # (e.g. k=3, s=2, p=1 gives 2H-1 instead of 2H).
        kwargs["output_padding"] = conv.output_padding
    return conv_layer(in_channels, out_channels, conv.kernel_size, **kwargs).to(
        conv.weight.device
    )


def prune(generator, threshold, channel_lower_bound):
    """Rebuild ``generator`` with narrower conv/norm pairs, in place.

    For every conv followed by a norm layer, the new width is the number of
    channels whose norm scale satisfies ``abs(weight) >= threshold``. It is
    clamped to ``[channel_lower_bound, original width]``; non-affine norms
    keep all channels. The conv and norm are replaced by freshly initialised
    modules of that width, and the next conv's input width is updated to
    match. Only channel *counts* are pruned: surviving weights are not copied,
    so the returned network must be (re)trained.

    Args:
        generator: Model exposing indexable, assignable ``down_sampling`` and
            ``up_sampling`` containers and an iterable ``features`` whose
            items have ``res_ops``, ``dw_ops`` and ``pw_bn``. The layout is
            assumed from the indexing below -- TODO confirm against the
            generator class.
        threshold: Channels with ``abs(norm.weight) < threshold`` are dropped.
        channel_lower_bound: Minimum width kept for every pruned layer.

    Returns:
        The same ``generator`` object, mutated in place.
    """
    # The norm type used for every rebuilt norm is inferred from the first
    # down-sampling norm.
    if isinstance(generator.down_sampling[2], nn.BatchNorm2d):
        norm_layer = nn.BatchNorm2d
    else:
        norm_layer = nn.InstanceNorm2d

    def prune_bn_and_conv(norm, conv, norm_layer, conv_layer, in_channels):
        """Return ``(new_norm, new_conv, width)`` for a conv -> norm pair."""
        out_channels = _kept_channels(norm, threshold, channel_lower_bound)
        pruned_norm = _rebuild_norm(norm, norm_layer, out_channels, conv.weight.device)
        pruned_conv = _rebuild_conv(conv, conv_layer, in_channels, out_channels)
        return pruned_norm, pruned_conv, out_channels

    # Down-sampling: the conv at i-1 feeds the norm at i, repeating every 3
    # entries starting at index 1 (index 0 is left untouched).
    in_channels = generator.down_sampling[1].in_channels
    for i in range(2, len(generator.down_sampling), 3):
        pruned_norm, pruned_conv, out_channels = prune_bn_and_conv(
            norm=generator.down_sampling[i],
            conv=generator.down_sampling[i - 1],
            norm_layer=norm_layer,
            conv_layer=nn.Conv2d,
            in_channels=in_channels,
        )
        generator.down_sampling[i] = pruned_norm
        generator.down_sampling[i - 1] = pruned_conv
        in_channels = out_channels

    # Feature blocks: every res/dw branch reads the block input (in_channels).
    # Only each branch's hidden width is pruned; its output width is kept.
    # NOTE(review): the block output width (pw_bn) is not pruned, yet the first
    # block's input is the pruned down-sampling width. If blocks add an
    # identity residual, these widths will disagree -- confirm against the model.
    for layer in generator.features:
        for sublayer in layer.res_ops:
            # sublayer[1] holds (conv, norm, ...), consumed by the conv sublayer[-1].
            sublayer[1][1], sublayer[1][0], hid_out_channels = prune_bn_and_conv(
                norm=sublayer[1][1],
                conv=sublayer[1][0],
                norm_layer=norm_layer,
                conv_layer=nn.Conv2d,
                in_channels=in_channels,
            )
            sublayer[-1] = _rebuild_conv(
                sublayer[-1], nn.Conv2d, hid_out_channels, sublayer[-1].out_channels
            )
        for sublayer in layer.dw_ops:
            # sublayer[0] holds (conv, norm, ...), consumed by the conv sublayer[2][0].
            # NOTE(review): sublayer[1] is not rebuilt. If it operates on
            # sublayer[0]'s output channels, its width will no longer match.
            sublayer[0][1], sublayer[0][0], hid_out_channels = prune_bn_and_conv(
                norm=sublayer[0][1],
                conv=sublayer[0][0],
                norm_layer=norm_layer,
                conv_layer=nn.Conv2d,
                in_channels=in_channels,
            )
            sublayer[2][0] = _rebuild_conv(
                sublayer[2][0], nn.Conv2d, hid_out_channels, sublayer[2][0].out_channels
            )
        in_channels = layer.pw_bn.num_features

    # Up-sampling: the transposed conv at i-1 feeds the norm at i, every 3
    # entries. The last three entries are excluded from the loop.
    for i in range(1, len(generator.up_sampling) - 3, 3):
        pruned_norm, pruned_conv, out_channels = prune_bn_and_conv(
            norm=generator.up_sampling[i],
            conv=generator.up_sampling[i - 1],
            norm_layer=norm_layer,
            conv_layer=nn.ConvTranspose2d,
            in_channels=in_channels,
        )
        generator.up_sampling[i] = pruned_norm
        generator.up_sampling[i - 1] = pruned_conv
        in_channels = out_channels

    # Final conv: only its input width changes; its output width is kept.
    conv = generator.up_sampling[-2]
    generator.up_sampling[-2] = _rebuild_conv(conv, nn.Conv2d, in_channels, conv.out_channels)
    return generator
Editor is loading...