# Untitled
# unknown
# python
# 4 years ago
# 5.6 kB
# 7
# Indexable
def _kept_channels(norm, threshold, channel_lower_bound):
    """Return how many channels of ``norm`` survive pruning, as a plain ``int``.

    A channel is pruned when the absolute value of its affine scale
    (``norm.weight``) is below ``threshold``. The result is clamped to
    ``[channel_lower_bound, norm.num_features]`` so pruning never makes a
    layer wider than it was.
    """
    num_features = norm.num_features
    if norm.weight is None:
        # Non-affine norms have no per-channel scale to rank by, so keep all.
        kept = num_features
    else:
        pruned = int((torch.abs(norm.weight) < threshold).sum().item())
        kept = num_features - pruned
    return min(max(kept, channel_lower_bound), num_features)


def _rebuild_conv(conv, conv_layer, in_channels, out_channels):
    """Construct a fresh ``conv_layer`` with new widths but ``conv``'s geometry.

    Kernel size, stride, padding, dilation and bias presence are copied from
    ``conv``. Transposed convolutions also keep ``output_padding`` (dropping
    it changes the up-sampled spatial size); regular convolutions keep
    ``padding_mode``. ``groups`` is not copied, so the rebuilt layer is always
    a groups=1 convolution. Weights are freshly initialised, not copied.
    """
    kwargs = dict(
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    if issubclass(conv_layer, nn.ConvTranspose2d):
        kwargs["output_padding"] = conv.output_padding
    else:
        kwargs["padding_mode"] = getattr(conv, "padding_mode", "zeros")
    return conv_layer(in_channels, out_channels, conv.kernel_size, **kwargs)


def _prune_norm_and_conv(norm, conv, norm_layer, conv_layer, in_channels,
                         threshold, channel_lower_bound):
    """Rebuild a conv -> norm pair at its pruned width.

    Returns ``(pruned_norm, pruned_conv, out_channels)`` where ``out_channels``
    is an ``int`` suitable as the next layer's ``in_channels``.
    """
    out_channels = _kept_channels(norm, threshold, channel_lower_bound)
    pruned_norm = norm_layer(
        out_channels,
        eps=norm.eps,
        momentum=norm.momentum,
        affine=norm.affine,
        track_running_stats=norm.track_running_stats,
    )
    pruned_conv = _rebuild_conv(conv, conv_layer, in_channels, out_channels)
    return pruned_norm, pruned_conv, out_channels


def prune(generator, threshold, channel_lower_bound):
    """Shrink ``generator`` in place by dropping channels with small norm scales.

    Layout read by this function:
      * ``generator.down_sampling``: conv at index ``i - 1`` and its norm at
        ``i`` for ``i = 2, 5, 8, ...``; ``down_sampling[1].in_channels`` is kept.
      * ``generator.features``: blocks exposing ``res_ops``, ``dw_ops`` and
        ``pw_bn``; each block's input width feeds every branch, and its output
        width is ``pw_bn.num_features`` (``pw_bn`` itself is not pruned).
      * ``generator.up_sampling``: transposed conv at ``i - 1`` and norm at
        ``i`` for ``i = 1, 4, ...``, excluding the last three modules;
        ``up_sampling[-2]`` is the output conv, whose output width is kept.

    Args:
        generator: model with the layout above; its modules are replaced.
        threshold: channels whose ``|norm.weight|`` is below this are dropped.
        channel_lower_bound: minimum number of channels kept per layer.

    Returns:
        The same ``generator`` object. Replaced layers are freshly constructed
        (weights are not copied) on the default device/dtype.
    """
    if isinstance(generator.down_sampling[2], nn.BatchNorm2d):
        norm_layer = nn.BatchNorm2d
    else:
        norm_layer = nn.InstanceNorm2d

    def prune_pair(norm, conv, conv_layer, in_channels):
        return _prune_norm_and_conv(norm, conv, norm_layer, conv_layer,
                                    in_channels, threshold, channel_lower_bound)

    # NOTE(review): the last down-sampling stage is pruned while pw_bn is not;
    # if the feature blocks add their input to their output, these widths can
    # disagree — confirm against the blocks' forward().
    down = generator.down_sampling
    in_channels = down[1].in_channels
    for i in range(2, len(down), 3):
        down[i], down[i - 1], in_channels = prune_pair(
            down[i], down[i - 1], nn.Conv2d, in_channels)

    for layer in generator.features:
        # pw_bn is intentionally left unpruned, so each block keeps emitting
        # pw_bn.num_features channels.
        for sublayer in layer.res_ops:
            sublayer[1][1], sublayer[1][0], hidden = prune_pair(
                sublayer[1][1], sublayer[1][0], nn.Conv2d, in_channels)
            sublayer[-1] = _rebuild_conv(
                sublayer[-1], nn.Conv2d, hidden, sublayer[-1].out_channels)
        for sublayer in layer.dw_ops:
            sublayer[0][1], sublayer[0][0], hidden = prune_pair(
                sublayer[0][1], sublayer[0][0], nn.Conv2d, in_channels)
            # NOTE(review): sublayer[1] is not rebuilt; if it also depends on
            # the hidden width (e.g. a depthwise conv), confirm it still fits.
            sublayer[2][0] = _rebuild_conv(
                sublayer[2][0], nn.Conv2d, hidden, sublayer[2][0].out_channels)
        in_channels = layer.pw_bn.num_features

    up = generator.up_sampling
    for i in range(1, len(up) - 3, 3):
        up[i], up[i - 1], in_channels = prune_pair(
            up[i], up[i - 1], nn.ConvTranspose2d, in_channels)

    # Output conv: its input follows the last pruned stage, output is unchanged.
    last_conv = up[-2]
    up[-2] = _rebuild_conv(last_conv, nn.Conv2d, in_channels,
                           last_conv.out_channels)
    return generator
# Editor is loading...