#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#
from torch import nn, Tensor
from typing import Dict, Tuple
import argparse
from utils import logger
# For visualization: when enabled, the intermediate feature maps produced by
# each stage are pickled to disk (as numpy arrays) during the first forward pass.
import os
import pickle
GET_MIDDLE_LAYER_FEATURE = True
DIR_NAME = "Extraction"
MODEL_NAME = f"{DIR_NAME}/Mobilevit"
from ... import parameter_list
from ...layers import norm_layers_tuple
from ...misc.profiler import module_profile
from ...misc.init_utils import initialize_weights
class BaseEncoder(nn.Module):
    def __init__(self, *args, **kwargs):
        super(BaseEncoder, self).__init__()
        self.conv_1 = None
        self.layer_1 = None
        self.layer_2 = None
        self.layer_3 = None
        self.layer_4 = None
        self.layer_5 = None
        self.conv_1x1_exp = None
        self.classifier = None
        self.round_nearest = 8
        self.model_conf_dict = dict()

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        return parser

    def check_model(self):
        assert self.model_conf_dict, "Model configuration dictionary should not be empty"
        assert self.conv_1 is not None, "Please implement self.conv_1"
        assert self.layer_1 is not None, "Please implement self.layer_1"
        assert self.layer_2 is not None, "Please implement self.layer_2"
        assert self.layer_3 is not None, "Please implement self.layer_3"
        assert self.layer_4 is not None, "Please implement self.layer_4"
        assert self.layer_5 is not None, "Please implement self.layer_5"
        assert self.conv_1x1_exp is not None, "Please implement self.conv_1x1_exp"
        assert self.classifier is not None, "Please implement self.classifier"

    def reset_parameters(self, opts):
        initialize_weights(opts=opts, modules=self.modules())

    def extract_end_points_all(self, x: Tensor, use_l5: bool = True, use_l5_exp: bool = False) -> Dict[str, Tensor]:
        out_dict = {}  # Use dictionary over NamedTuple so that JIT is happy
        x = self.conv_1(x)  # 112 x 112
        x = self.layer_1(x)  # 112 x 112
        out_dict["out_l1"] = x
        x = self.layer_2(x)  # 56 x 56
        out_dict["out_l2"] = x
        x = self.layer_3(x)  # 28 x 28
        out_dict["out_l3"] = x
        x = self.layer_4(x)  # 14 x 14
        out_dict["out_l4"] = x
        if use_l5:
            x = self.layer_5(x)  # 7 x 7
            out_dict["out_l5"] = x
            if use_l5_exp:
                x = self.conv_1x1_exp(x)
                out_dict["out_l5_exp"] = x
        return out_dict

    def extract_end_points_l4(self, x: Tensor) -> Dict[str, Tensor]:
        return self.extract_end_points_all(x, use_l5=False)
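
    # Illustrative usage (hypothetical driver code, not part of this module):
    # a concrete subclass can expose multi-scale features for a detection or
    # segmentation head via the end-point dictionary returned above.
    #
    #   end_points = model.extract_end_points_all(images)
    #   p3, p4 = end_points["out_l3"], end_points["out_l4"]  # 28x28 and 14x14 maps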

    def extract_features(self, x: Tensor) -> Tensor:
        if GET_MIDDLE_LAYER_FEATURE and not os.path.isdir(DIR_NAME):
            os.mkdir(DIR_NAME)

        def dump_feature(tensor: Tensor, index: int, tag: str) -> None:
            # Pickle an intermediate feature map as a numpy array. detach() is
            # needed in case autograd is tracking the tensor; indices run 1..7
            # here, and forward() writes index 8 for the classifier output.
            if GET_MIDDLE_LAYER_FEATURE:
                with open(f"{MODEL_NAME}_{index}_{tag}.pkl", "wb") as f:
                    pickle.dump(tensor.detach().cpu().numpy(), f)

        x = self.conv_1(x)
        dump_feature(x, 1, "conv_1")
        x = self.layer_1(x)
        dump_feature(x, 2, "layer_1")
        x = self.layer_2(x)
        dump_feature(x, 3, "layer_2")
        x = self.layer_3(x)
        dump_feature(x, 4, "layer_3")
        x = self.layer_4(x)
        dump_feature(x, 5, "layer_4")
        x = self.layer_5(x)
        dump_feature(x, 6, "layer_5")
        x = self.conv_1x1_exp(x)
        dump_feature(x, 7, "conv_1x1_exp")
        return x

    def forward(self, x: Tensor) -> Tensor:
        global GET_MIDDLE_LAYER_FEATURE
        x = self.extract_features(x)
        x = self.classifier(x)
        if GET_MIDDLE_LAYER_FEATURE:
            with open(f"{MODEL_NAME}_8_classifier.pkl", "wb") as f:
                pickle.dump(x.detach().cpu().numpy(), f)
            # Only extract once: disable dumping after the first forward pass
            GET_MIDDLE_LAYER_FEATURE = False
        return x
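
    # Illustrative one-shot dump (hypothetical driver code): a single inference
    # pass writes Extraction/Mobilevit_1_conv_1.pkl through
    # Extraction/Mobilevit_8_classifier.pkl, then dumping disables itself.
    #
    #   model.eval()
    #   with torch.no_grad():
    #       logits = model(torch.randn(1, 3, 224, 224))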

    def freeze_norm_layers(self):
        for m in self.modules():
            if isinstance(m, norm_layers_tuple):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False
                m.training = False

    def get_trainable_parameters(self, weight_decay: float = 0.0, no_decay_bn_filter_bias: bool = False):
        param_list = parameter_list(named_parameters=self.named_parameters,
                                    weight_decay=weight_decay,
                                    no_decay_bn_filter_bias=no_decay_bn_filter_bias)
        return param_list, [1.0] * len(param_list)
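
    # Sketch of wiring the returned parameter groups into a standard PyTorch
    # optimizer (an assumption about the surrounding training loop, not part of
    # this module). The second return value is a per-group LR multiplier.
    #
    #   param_groups, lr_mult = model.get_trainable_parameters(weight_decay=4e-5)
    #   optimizer = torch.optim.SGD(param_groups, lr=0.05, momentum=0.9)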

    @staticmethod
    def _profile_layers(layers, input, overall_params, overall_macs):
        if not isinstance(layers, list):
            layers = [layers]
        for layer in layers:
            if layer is None:
                continue
            input, layer_param, layer_macs = module_profile(module=layer, x=input)
            overall_params += layer_param
            overall_macs += layer_macs
            if isinstance(layer, nn.Sequential):
                module_name = "\n+".join([l.__class__.__name__ for l in layer])
            else:
                module_name = layer.__class__.__name__
            print('{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M'.format(module_name,
                                                                            'Params',
                                                                            round(layer_param / 1e6, 3),
                                                                            'MACs',
                                                                            round(layer_macs / 1e6, 3)))
            logger.singe_dash_line()
        return input, overall_params, overall_macs

    def profile_model(self, input: Tensor, is_classification: bool = True) -> Tuple[Dict[str, Tensor], float, float]:
        # Note: Model profiling is for reference only and may contain errors.
        # It relies heavily on the user to implement the underlying functions accurately.
        overall_params, overall_macs = 0.0, 0.0

        if is_classification:
            logger.log('Model statistics for an input of size {}'.format(input.size()))
            logger.double_dash_line(dashes=65)
            print('{:>35} Summary'.format(self.__class__.__name__))
            logger.double_dash_line(dashes=65)

        out_dict = {}
        input, overall_params, overall_macs = self._profile_layers([self.conv_1, self.layer_1], input=input,
                                                                   overall_params=overall_params,
                                                                   overall_macs=overall_macs)
        out_dict["out_l1"] = input

        input, overall_params, overall_macs = self._profile_layers(self.layer_2, input=input,
                                                                   overall_params=overall_params,
                                                                   overall_macs=overall_macs)
        out_dict["out_l2"] = input

        input, overall_params, overall_macs = self._profile_layers(self.layer_3, input=input,
                                                                   overall_params=overall_params,
                                                                   overall_macs=overall_macs)
        out_dict["out_l3"] = input

        input, overall_params, overall_macs = self._profile_layers(self.layer_4, input=input,
                                                                   overall_params=overall_params,
                                                                   overall_macs=overall_macs)
        out_dict["out_l4"] = input

        input, overall_params, overall_macs = self._profile_layers(self.layer_5, input=input,
                                                                   overall_params=overall_params,
                                                                   overall_macs=overall_macs)
        out_dict["out_l5"] = input

        if self.conv_1x1_exp is not None:
            input, overall_params, overall_macs = self._profile_layers(self.conv_1x1_exp, input=input,
                                                                       overall_params=overall_params,
                                                                       overall_macs=overall_macs)
            out_dict["out_l5_exp"] = input

        if is_classification:
            classifier_params, classifier_macs = 0.0, 0.0
            if self.classifier is not None:
                input, classifier_params, classifier_macs = module_profile(module=self.classifier, x=input)
                print('{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M'.format('Classifier',
                                                                                'Params',
                                                                                round(classifier_params / 1e6, 3),
                                                                                'MACs',
                                                                                round(classifier_macs / 1e6, 3)))
            overall_params += classifier_params
            overall_macs += classifier_macs

        logger.double_dash_line(dashes=65)
        print('{:<20} = {:>8.3f} M'.format('Overall parameters', overall_params / 1e6))
        # Counting one multiplication plus one addition as a single MAC operation
        print('{:<20} = {:>8.3f} M'.format('Overall MACs', overall_macs / 1e6))
        overall_params_py = sum([p.numel() for p in self.parameters()])
        print('{:<20} = {:>8.3f} M'.format('Overall parameters (sanity check)', overall_params_py / 1e6))
        logger.double_dash_line(dashes=65)
        return out_dict, overall_params, overall_macs
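
# Illustrative post-hoc inspection (hypothetical, not part of this module):
# loading one of the feature dumps written by extract_features() above.
#
#   import pickle
#   with open("Extraction/Mobilevit_2_layer_1.pkl", "rb") as f:
#       feat = pickle.load(f)  # numpy array, e.g. (N, C, 112, 112)
#   print(feat.shape, feat.dtype)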