Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
5.6 kB
3
Indexable
Never
def _count_kept_channels(norm, threshold, channel_lower_bound):
    """Return, as a plain ``int``, how many channels of ``norm`` survive pruning."""
    total = norm.num_features
    if norm.weight is None:
        # Non-affine norm layers carry no scale factors to threshold against,
        # so there is nothing to prune.
        return total
    pruned = int((norm.weight.detach().abs() < threshold).sum().item())
    # Respect the lower bound, but never make the layer wider than it was.
    return min(max(total - pruned, channel_lower_bound), total)


def _rebuild_conv(conv, conv_layer, in_channels, out_channels):
    """Build a fresh ``conv_layer`` with new channel counts and ``conv``'s geometry."""
    kwargs = {
        "stride": conv.stride,
        "padding": conv.padding,
        "dilation": conv.dilation,
        "padding_mode": getattr(conv, "padding_mode", "zeros"),
        "bias": conv.bias is not None,
    }
    if conv_layer is nn.ConvTranspose2d:
        # Dropping output_padding would change the up-sampled spatial size.
        kwargs["output_padding"] = getattr(conv, "output_padding", 0)
    # ``groups`` is deliberately not copied: the channel counts change, so the
    # old group count is not guaranteed to divide them any more.
    return conv_layer(in_channels, out_channels, conv.kernel_size, **kwargs)


def _prune_norm_and_conv(norm, conv, norm_layer, conv_layer, in_channels,
                         threshold, channel_lower_bound):
    """Return ``(new_norm, new_conv, out_channels)`` for a conv -> norm pair."""
    out_channels = _count_kept_channels(norm, threshold, channel_lower_bound)
    pruned_norm = norm_layer(
        out_channels,
        eps=norm.eps,
        momentum=norm.momentum,
        affine=norm.affine,
        track_running_stats=norm.track_running_stats,
    )
    pruned_conv = _rebuild_conv(conv, conv_layer, in_channels, out_channels)
    return pruned_norm, pruned_conv, out_channels


def prune(generator, threshold, channel_lower_bound):
    """Shrink the conv/norm layers of ``generator`` according to norm scales.

    For every convolution that is followed by a normalization layer, the
    number of output channels is reduced by the count of norm scale factors
    whose absolute value is below ``threshold``, but never below
    ``channel_lower_bound`` (and never above the layer's original width).
    The convolution consuming those channels is rebuilt with a matching
    number of input channels.

    The replacement layers are newly constructed modules: no weights are
    copied from the layers they replace, and they are created on the default
    device/dtype.

    Args:
        generator: Module exposing ``down_sampling``, ``features`` and
            ``up_sampling`` containers. ``generator`` is modified in place.
        threshold: Norm scale factors with ``abs(weight) < threshold`` count
            as pruned channels.
        channel_lower_bound: Minimum number of channels kept per layer.

    Returns:
        The same ``generator`` object, with its layers replaced.
    """
    if isinstance(generator.down_sampling[2], nn.BatchNorm2d):
        norm_layer = nn.BatchNorm2d
    else:
        norm_layer = nn.InstanceNorm2d

    in_channels = generator.down_sampling[1].in_channels

    # Down-sampling: the conv sits at index i - 1 and its norm at index i,
    # repeating every three modules.
    for i in range(2, len(generator.down_sampling), 3):
        pruned_norm, pruned_conv, out_channels = _prune_norm_and_conv(
            norm=generator.down_sampling[i],
            conv=generator.down_sampling[i - 1],
            norm_layer=norm_layer,
            conv_layer=nn.Conv2d,
            in_channels=in_channels,
            threshold=threshold,
            channel_lower_bound=channel_lower_bound,
        )
        generator.down_sampling[i] = pruned_norm
        generator.down_sampling[i - 1] = pruned_conv
        in_channels = out_channels

    for layer in generator.features:
        # Only the hidden (inner) width of each op is pruned. Modules other
        # than the ones reassigned below are left as they are.
        # NOTE(review): confirm those untouched modules do not depend on the
        # hidden channel count.
        for sublayer in layer.res_ops:
            sublayer[1][1], sublayer[1][0], hid_out_channels = _prune_norm_and_conv(
                norm=sublayer[1][1],
                conv=sublayer[1][0],
                norm_layer=norm_layer,
                conv_layer=nn.Conv2d,
                in_channels=in_channels,
                threshold=threshold,
                channel_lower_bound=channel_lower_bound,
            )
            sublayer[-1] = _rebuild_conv(
                sublayer[-1], nn.Conv2d, hid_out_channels, sublayer[-1].out_channels
            )

        for sublayer in layer.dw_ops:
            sublayer[0][1], sublayer[0][0], hid_out_channels = _prune_norm_and_conv(
                norm=sublayer[0][1],
                conv=sublayer[0][0],
                norm_layer=norm_layer,
                conv_layer=nn.Conv2d,
                in_channels=in_channels,
                threshold=threshold,
                channel_lower_bound=channel_lower_bound,
            )
            sublayer[2][0] = _rebuild_conv(
                sublayer[2][0], nn.Conv2d, hid_out_channels, sublayer[2][0].out_channels
            )

        # ``pw_bn`` itself is not pruned, so the next layer is fed its full
        # ``num_features`` channels.
        in_channels = layer.pw_bn.num_features

    # Up-sampling: the transposed conv sits at index i - 1 and its norm at
    # index i. The last three modules are excluded from this loop.
    for i in range(1, len(generator.up_sampling) - 3, 3):
        pruned_norm, pruned_conv, out_channels = _prune_norm_and_conv(
            norm=generator.up_sampling[i],
            conv=generator.up_sampling[i - 1],
            norm_layer=norm_layer,
            conv_layer=nn.ConvTranspose2d,
            in_channels=in_channels,
            threshold=threshold,
            channel_lower_bound=channel_lower_bound,
        )
        generator.up_sampling[i] = pruned_norm
        generator.up_sampling[i - 1] = pruned_conv
        in_channels = out_channels

    # The final conv keeps its output width; only its input width has to
    # follow the last pruned layer.
    conv = generator.up_sampling[-2]
    generator.up_sampling[-2] = _rebuild_conv(
        conv, nn.Conv2d, in_channels, conv.out_channels
    )
    return generator