Untitled

mail@pastecode.io avatar
unknown
plain_text
2 years ago
11 kB
1
Indexable
Never
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#

from genericpath import isdir
from torch import nn, Tensor
from typing import Optional, Dict
import argparse

from utils import logger

# For visualization: optional dumping of intermediate activations to disk.
import os
import pickle
# When True, BaseEncoder.extract_features / BaseEncoder.forward pickle the
# output of every stage; forward() sets it to False after its first call, so
# the dump happens only once per process.
GET_MIDDEN_LAYER_FUNCTURE = True
# Directory that receives the dump files (created on demand by
# BaseEncoder.extract_features).
DIR_NAME = "Extraction"
# Path prefix shared by every dump file name.
MODEL_NAME = f"{DIR_NAME}/Mobilevit"
# Running counter embedded in the conv_1 dump file name; extract_features
# increments it after each conv_1 dump.
LAYER_ID = 0

from ... import parameter_list
from ...layers import norm_layers_tuple
from ...misc.profiler import module_profile
from ...misc.init_utils import initialize_weights


class BaseEncoder(nn.Module):
    """Base class for encoders built from a stem, five stages and a classifier.

    Sub-classes are expected to assign ``conv_1``, ``layer_1`` .. ``layer_5``,
    ``conv_1x1_exp``, ``classifier`` and a non-empty ``model_conf_dict``;
    :meth:`check_model` asserts that they did.

    While the module-level flag ``GET_MIDDEN_LAYER_FUNCTURE`` is true,
    :meth:`extract_features` and :meth:`forward` additionally pickle the
    output of every stage to ``f"{MODEL_NAME}_<tag>.pkl"``.
    """

    def __init__(self, *args, **kwargs):
        """Create the placeholders that concrete encoders must fill in."""
        super(BaseEncoder, self).__init__()
        self.conv_1 = None
        self.layer_1 = None
        self.layer_2 = None
        self.layer_3 = None
        self.layer_4 = None
        self.layer_5 = None
        self.conv_1x1_exp = None
        self.classifier = None
        self.round_nearest = 8

        self.model_conf_dict = dict()

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Return ``parser`` unchanged; sub-classes may register options."""
        return parser

    def check_model(self):
        """Assert that the configuration dict and every layer have been set.

        Raises:
            AssertionError: if ``model_conf_dict`` is empty or any of the
                expected layers is still ``None``.
        """
        assert self.model_conf_dict, "Model configuration dictionary should not be empty"
        required = (
            "conv_1",
            "layer_1",
            "layer_2",
            "layer_3",
            "layer_4",
            "layer_5",
            "conv_1x1_exp",
            "classifier",
        )
        for attr_name in required:
            assert getattr(self, attr_name) is not None, 'Please implement self.{}'.format(attr_name)

    def reset_parameters(self, opts):
        """Re-initialize all sub-modules through ``initialize_weights``."""
        initialize_weights(opts=opts, modules=self.modules())

    def extract_end_points_all(self, x: Tensor, use_l5: Optional[bool] = True, use_l5_exp: Optional[bool] = False) -> Dict:
        """Run the encoder and collect the output of every stage.

        Args:
            x: Input tensor.
            use_l5: Also run ``layer_5`` and store its output as ``"out_l5"``.
            use_l5_exp: Also run ``conv_1x1_exp`` and store its output as
                ``"out_l5_exp"``. Only honoured when ``use_l5`` is true.

        Returns:
            Dict mapping ``"out_l1"`` .. ``"out_l4"`` (plus the optional keys
            above) to the corresponding stage output.
        """
        # NOTE(review): the spatial sizes in the inline comments presumably
        # assume a 224 x 224 input -- confirm against the concrete model.
        out_dict = {}  # Use dictionary over NamedTuple so that JIT is happy
        x = self.conv_1(x)  # 112 x112
        x = self.layer_1(x)  # 112 x112
        out_dict["out_l1"] = x

        x = self.layer_2(x)  # 56 x 56
        out_dict["out_l2"] = x

        x = self.layer_3(x)  # 28 x 28
        out_dict["out_l3"] = x

        x = self.layer_4(x)  # 14 x 14
        out_dict["out_l4"] = x

        if use_l5:
            x = self.layer_5(x)  # 7 x 7
            out_dict["out_l5"] = x

            if use_l5_exp:
                x = self.conv_1x1_exp(x)
                out_dict["out_l5_exp"] = x
        return out_dict

    def extract_end_points_l4(self, x: Tensor) -> Dict:
        """Return the stage outputs up to and including ``layer_4``."""
        return self.extract_end_points_all(x, use_l5=False)

    @staticmethod
    def _dump_activation(x: Tensor, tag: str) -> None:
        """Pickle ``x`` as a NumPy array to ``f"{MODEL_NAME}_{tag}.pkl"``."""
        # detach() is required: Tensor.numpy() raises on tensors that require
        # grad, which is the normal case during training.
        with open(f"{MODEL_NAME}_{tag}.pkl", "wb") as f:
            pickle.dump(x.detach().cpu().numpy(), f)

    def extract_features(self, x: Tensor) -> Tensor:
        """Run the stem, the five stages and ``conv_1x1_exp``.

        While ``GET_MIDDEN_LAYER_FUNCTURE`` is true, the output of every step
        is also written to disk via :meth:`_dump_activation`.

        Args:
            x: Input tensor.

        Returns:
            Output of ``conv_1x1_exp``.
        """
        global LAYER_ID

        dump = GET_MIDDEN_LAYER_FUNCTURE
        if dump:
            # exist_ok avoids the check-then-create race of isdir() + mkdir().
            os.makedirs(DIR_NAME, exist_ok=True)

        x = self.conv_1(x)
        if dump:
            # NOTE(review): conv_1 uses the running LAYER_ID counter while the
            # later stages use fixed indices 2..8; kept as-is so existing
            # consumers of the dump files keep working.
            self._dump_activation(x, f"{LAYER_ID}_conv_1")
            LAYER_ID += 1

        stages = (
            ("2_layer_1", self.layer_1),
            ("3_layer_2", self.layer_2),
            ("4_layer_3", self.layer_3),
            ("5_layer_4", self.layer_4),
            ("6_layer_5", self.layer_5),
            ("7_conv_1x1_exp", self.conv_1x1_exp),
        )
        for tag, stage in stages:
            x = stage(x)
            if dump:
                self._dump_activation(x, tag)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Compute classifier outputs for ``x``.

        On the first call in a process (while ``GET_MIDDEN_LAYER_FUNCTURE`` is
        true) the classifier output is dumped too; the flag is then cleared.
        """
        global GET_MIDDEN_LAYER_FUNCTURE

        x = self.extract_features(x)
        x = self.classifier(x)
        if GET_MIDDEN_LAYER_FUNCTURE:
            self._dump_activation(x, "8_classifier")

        # Only extract once
        GET_MIDDEN_LAYER_FUNCTURE = False
        return x

    def freeze_norm_layers(self):
        """Put every normalization layer in eval mode and freeze its affine
        parameters."""
        for m in self.modules():
            if not isinstance(m, norm_layers_tuple):
                continue
            m.eval()
            # Norm layers created without affine parameters expose
            # weight/bias as None, so guard before touching requires_grad.
            for param in (getattr(m, "weight", None), getattr(m, "bias", None)):
                if param is not None:
                    param.requires_grad = False
            m.training = False

    def get_trainable_parameters(self, weight_decay: float = 0.0, no_decay_bn_filter_bias: bool = False):
        """Build the parameter groups for the optimizer.

        Returns:
            Tuple ``(param_list, multipliers)`` where ``param_list`` is the
            result of ``parameter_list`` and ``multipliers`` holds ``1.0``
            for each entry of ``param_list``.
        """
        param_list = parameter_list(named_parameters=self.named_parameters,
                                    weight_decay=weight_decay,
                                    no_decay_bn_filter_bias=no_decay_bn_filter_bias)
        return param_list, [1.0] * len(param_list)

    @staticmethod
    def _profile_layers(layers, input, overall_params, overall_macs):
        """Profile ``layers`` in sequence and accumulate their statistics.

        Args:
            layers: A single layer or a list of layers; ``None`` entries are
                skipped.
            input: Tensor fed to the first layer.
            overall_params: Running parameter count to add to.
            overall_macs: Running MAC count to add to.

        Returns:
            Tuple ``(output, overall_params, overall_macs)``.
        """
        if not isinstance(layers, list):
            layers = [layers]

        for layer in layers:
            if layer is None:
                continue
            input, layer_param, layer_macs = module_profile(module=layer, x=input)

            overall_params += layer_param
            overall_macs += layer_macs

            if isinstance(layer, nn.Sequential):
                module_name = "\n+".join([sub.__class__.__name__ for sub in layer])
            else:
                module_name = layer.__class__.__name__
            print(
                '{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M'.format(module_name,
                                                                          'Params',
                                                                          round(layer_param / 1e6, 3),
                                                                          'MACs',
                                                                          round(layer_macs / 1e6, 3)
                                                                          ))
            logger.singe_dash_line()
        return input, overall_params, overall_macs

    def profile_model(self, input: Tensor, is_classification: bool = True) -> (Dict, float, float):
        """Print per-stage parameter/MAC statistics and return the end points.

        Args:
            input: Tensor used to drive the profiling pass.
            is_classification: When true, also profile the classifier and
                print the summary header and totals.

        Returns:
            Tuple ``(out_dict, overall_params, overall_macs)`` where
            ``out_dict`` maps ``"out_l1"`` .. ``"out_l5"`` (and
            ``"out_l5_exp"`` when ``conv_1x1_exp`` is set) to stage outputs.
        """
        # Note: Model profiling is for reference only and may contain errors.
        # It relies heavily on the user to implement the underlying functions accurately.
        overall_params, overall_macs = 0.0, 0.0

        if is_classification:
            logger.log('Model statistics for an input of size {}'.format(input.size()))
            logger.double_dash_line(dashes=65)
            print('{:>35} Summary'.format(self.__class__.__name__))
            logger.double_dash_line(dashes=65)

        out_dict = {}
        stage_plan = (
            ("out_l1", [self.conv_1, self.layer_1]),
            ("out_l2", self.layer_2),
            ("out_l3", self.layer_3),
            ("out_l4", self.layer_4),
            ("out_l5", self.layer_5),
        )
        for key, layers in stage_plan:
            input, overall_params, overall_macs = self._profile_layers(layers, input=input,
                                                                       overall_params=overall_params,
                                                                       overall_macs=overall_macs)
            out_dict[key] = input

        if self.conv_1x1_exp is not None:
            input, overall_params, overall_macs = self._profile_layers(self.conv_1x1_exp, input=input,
                                                                       overall_params=overall_params,
                                                                       overall_macs=overall_macs)
            out_dict["out_l5_exp"] = input

        if is_classification:
            classifier_params, classifier_macs = 0.0, 0.0
            if self.classifier is not None:
                input, classifier_params, classifier_macs = module_profile(module=self.classifier, x=input)
                print('{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M'.format('Classifier',
                                                                                'Params',
                                                                                round(classifier_params / 1e6, 3),
                                                                                'MACs',
                                                                                round(classifier_macs / 1e6, 3)))
            overall_params += classifier_params
            overall_macs += classifier_macs

            logger.double_dash_line(dashes=65)
            print('{:<20} = {:>8.3f} M'.format('Overall parameters', overall_params / 1e6))
            # Counting Addition and Multiplication as 1 operation
            print('{:<20} = {:>8.3f} M'.format('Overall MACs', overall_macs / 1e6))
            overall_params_py = sum([p.numel() for p in self.parameters()])
            print('{:<20} = {:>8.3f} M'.format('Overall parameters (sanity check)', overall_params_py / 1e6))
            logger.double_dash_line(dashes=65)

        return out_dict, overall_params, overall_macs