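Intel Distiller Toolkit: Quantization Implementation, Part 1
2022-07-06 08:51:00 【cyz0202】
Articles in this series
Distiller
- Distiller is a neural network compression toolkit developed by Intel around 2019. The methods it supports include pruning, quantization, distillation, and low-rank decomposition.
- This article describes how Distiller implements quantization. Distiller has received almost no updates since 2019, so the focus is on its classic quantization schemes, for learning purposes.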
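Distiller quantization implementation
- First, I use the GNMT quantization code from the Distiller examples to introduce Distiller's quantization framework. The code is shown in the figure below.
  [Figure: GNMT quantization code from the Distiller examples (image not included)]
- That code has three steps:
  - Collect statistics (the calibrator, QuantCalibrationStatsCollector)
  - Create the quantizer (here a post-training quantizer, PostTrainLinearQuantizer)
  - Apply the quantizer to the model (prepare_model)
- The rest of this article focuses on creating and applying the quantizer (steps 2 and 3 above).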
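The three steps can be illustrated without Distiller itself. The following is a minimal pure-Python sketch, not Distiller's API: `RangeStats`, `asymmetric_params`, `quantize`, and `dequantize` are hypothetical names, and the asymmetric min/max scheme is one common form of linear quantization assumed here only for illustration.

```python
from dataclasses import dataclass


@dataclass
class RangeStats:
    """Running min/max of values seen during calibration (hypothetical helper)."""
    min_val: float = float("inf")
    max_val: float = float("-inf")

    def update(self, values):
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))


def asymmetric_params(stats, num_bits):
    """Derive (scale, zero_point) mapping [min, max] onto [0, 2**num_bits - 1]."""
    # Extend the range to contain 0 so that zero is exactly representable.
    lo, hi = min(stats.min_val, 0.0), max(stats.max_val, 0.0)
    n_levels = 2 ** num_bits - 1
    scale = (hi - lo) / n_levels if hi > lo else 1.0
    zero_point = round(-lo / scale)
    return scale, zero_point


def quantize(values, scale, zero_point, num_bits):
    """Map floats to integers in [0, 2**num_bits - 1], clamping out-of-range values."""
    q_max = 2 ** num_bits - 1
    return [min(max(round(v / scale) + zero_point, 0), q_max) for v in values]


def dequantize(q_values, scale, zero_point):
    """Map quantized integers back to approximate floats."""
    return [(q - zero_point) * scale for q in q_values]


# Step 1: collect statistics on calibration batches.
stats = RangeStats()
for batch in ([-1.0, 0.5, 2.0], [0.25, 3.0, -0.5]):
    stats.update(batch)

# Step 2: turn the statistics into quantization parameters.
scale, zero_point = asymmetric_params(stats, num_bits=8)

# Step 3: quantize data with the calibrated parameters.
x = [-1.0, 0.0, 1.5, 3.0]
q = quantize(x, scale, zero_point, num_bits=8)
x_hat = dequantize(q, scale, zero_point)
```

In this sketch the reconstruction error of each value in `x_hat` stays within one quantization step (`scale`), because every input lies inside the calibrated range.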
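Quantizer
- Distiller implements quantizers mainly through module replacement: each module to be quantized, such as conv, linear, embedding, or a parameter-free op (add, multiply, concat), is enclosed in a wrapper. At inference time the wrapper quantizes the inputs and the weights (parameter-free modules have none), hands them to the wrapped real module (conv, linear, parameter-free op, etc.) for the computation, and finally dequantizes the result if needed.
  - Note: Distiller wraps addition (elementwise_add), element-wise multiplication (elementwise_mul), matrix multiplication (matMul), batch matrix multiplication (BatchMatMul), and concat in nn.Module classes, so that module replacement can quantize all of them in a uniform way.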
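The module-replacement idea can be sketched in pure Python. This is a toy under stated assumptions, not Distiller code: `EltwiseAdd`, `Container`, `FakeQuantWrapper`, and `replace_modules` are hypothetical stand-ins, the symmetric scale is assumed to be known from calibration, and the wrapper only simulates quantization (quantize, then dequantize the inputs) before calling the wrapped op in floating point.

```python
from collections import defaultdict


class EltwiseAdd:
    """Parameter-free op expressed as an object, so it can be found and replaced."""
    def __call__(self, a, b):
        return [x + y for x, y in zip(a, b)]


class Container:
    """Toy stand-in for a model: holds named child ops as attributes."""
    def __init__(self, **children):
        for name, child in children.items():
            setattr(self, name, child)

    def named_children(self):
        return list(vars(self).items())


class FakeQuantWrapper:
    """Simulates quantization of the inputs, then calls the wrapped op."""
    def __init__(self, wrapped, num_bits, max_abs):
        self.wrapped = wrapped
        self.q_max = 2 ** (num_bits - 1) - 1
        self.scale = max_abs / self.q_max  # symmetric scale (assumed calibrated)

    def _fake_quant(self, values):
        # Round to a signed integer grid, clamp, and map back to float.
        return [max(-self.q_max, min(self.q_max, round(v / self.scale))) * self.scale
                for v in values]

    def __call__(self, *inputs):
        return self.wrapped(*(self._fake_quant(x) for x in inputs))


def replace_modules(container, replacement_factory):
    """Swap each child whose type has a registered replacement function."""
    for name, child in container.named_children():
        replace_fn = replacement_factory[type(child)]
        if replace_fn is not None:  # unregistered types are left untouched
            setattr(container, name, replace_fn(child))


# Unregistered types map to None, mirroring a defaultdict-based factory.
factory = defaultdict(lambda: None)
factory[EltwiseAdd] = lambda op: FakeQuantWrapper(op, num_bits=8, max_abs=4.0)

model = Container(add=EltwiseAdd(), other=object())
replace_modules(model, factory)
out = model.add([1.0, -2.0], [0.5, 3.3])
```

Because addition is an object rather than a bare `+`, the traversal can find it by type and wrap it, which is the reason for expressing parameter-free ops as modules.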
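- Now let us look at the definition of the Quantizer base class. The code is given below with my comments added; readers who have not read the full Distiller source may still find it hard to follow, so I recommend that interested readers go through the source. The class consists mainly of the following parts:
  - Important variables
    - Quantization bit settings: the default settings for each module and the externally supplied overrides; the QBits class is defined for this
    - Replacement factory: replacement_factory, a dict that records which wrapper replaces each type of module to be quantized. It is populated by the concrete Quantizer subclass and is covered in the next article.
    - Parameters to quantize: params_to_quantize records every parameter to be quantized together with its context (owning module, number of bits, etc.)
    - Parameter quantization function: param_quantization_fn, the function that quantizes parameters, set by the Quantizer subclass
  - Processing flow (prepare_model)
    - Preprocessing: for example BN folding (BN is used differently in training and in inference) and activation optimization
    - Replacement: each module to be quantized is replaced with its wrapper
    - Postprocessing
  - The source code is as follows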
    import copy
    import re
    import warnings
    from collections import OrderedDict, defaultdict

    import torch.nn as nn
    import distiller

    # QBits, _ParamToQuant, FP_BKP_PREFIX, hack_float_backup_parameter and msglogger
    # are defined elsewhere in Distiller's quantizer source and are not shown here.


    class Quantizer(object):
        r"""
        Base class for quantizers.

        Args:
            model (torch.nn.Module): The model to be quantized
            optimizer (torch.optim.Optimizer): An optimizer instance, required in cases where the quantizer is going
                to perform changes to existing model parameters and/or add new ones.
                Specifically, when train_with_fp_copy is True, this cannot be None.
            bits_activations/weights/bias (int): Default number of bits to use when quantizing each tensor type.
                Value of None means do not quantize.
            overrides (OrderedDict): Dictionary mapping regular expressions of layer name patterns to dictionary with
                overrides of default values.
                The keys in the overrides dictionary should be parameter names that the Quantizer accepts default values
                for in its init function.
                The parameters 'bits_activations', 'bits_weights', and 'bits_bias' which are accepted by the base Quantizer
                are supported by default.
                Other than those, each sub-class of Quantizer defines the set of parameter for which it supports
                over-riding.
                OrderedDict is used to enable handling of overlapping name patterns. So, for example, one could define
                certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for
                specific layers in that group, e.g. 'conv1'.
                The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must
                come before the broad patterns.
            train_with_fp_copy (bool): If true, will modify layers with weights to keep both a quantized and
                floating-point copy, such that the following flow occurs in each training iteration:
                1. q_weights = quantize(fp_weights)
                2. Forward through network using q_weights
                3. In back-prop:
                    3.1 Gradients calculated with respect to q_weights
                    3.2 We also back-prop through the 'quantize' operation from step 1
                4. Update fp_weights with gradients calculated in step 3.2
        """
        def __init__(self, model, optimizer=None,
                     bits_activations=None, bits_weights=None, bits_bias=None,
                     overrides=None, train_with_fp_copy=False):
            if overrides is None:
                overrides = OrderedDict()
            if not isinstance(overrides, OrderedDict):
                raise TypeError('overrides must be an instance of collections.OrderedDict or None')

            if train_with_fp_copy and optimizer is None:
                raise ValueError('optimizer cannot be None when train_with_fp_copy is True')

            # Relationships between computation-graph nodes, used later for activation optimization
            self.adjacency_map = None  # To be populated during prepare_model()

            # Default quantization bit settings
            self.default_qbits = QBits(acts=bits_activations, wts=bits_weights, bias=bits_bias)
            self.overrides = overrides

            self.model = model
            self.optimizer = optimizer

            # Stash some quantizer data in the model so we can re-apply the quantizer on a resuming model
            self.model.quantizer_metadata = {'type': type(self),
                                             'params': {'bits_activations': bits_activations,
                                                        'bits_weights': bits_weights,
                                                        'bits_bias': bits_bias,
                                                        'overrides': copy.deepcopy(overrides)}}

            for k, v in self.overrides.items():
                if any(old_bits_key in v.keys() for old_bits_key in ['acts', 'wts', 'bias']):
                    raise ValueError("Using 'acts' / 'wts' / 'bias' to specify bit-width overrides is deprecated.\n"
                                     "Please use the full parameter names: "
                                     "'bits_activations' / 'bits_weights' / 'bits_bias'")
                qbits = QBits(acts=v.pop('bits_activations', self.default_qbits.acts),
                              wts=v.pop('bits_weights', self.default_qbits.wts),
                              bias=v.pop('bits_bias', self.default_qbits.bias))
                v['bits'] = qbits

            # Prepare explicit mapping from each layer to QBits based on default + overrides
            patterns = []
            regex_overrides = None
            # Some modules need their default quantization settings overridden
            if overrides:
                patterns = list(overrides.keys())
                regex_overrides_str = '|'.join(['(^{0}$)'.format(pattern) for pattern in patterns])
                regex_overrides = re.compile(regex_overrides_str)

            self.module_qbits_map = {}
            self.module_overrides_map = {}
            # Set the quantization bits of every module
            for module_full_name, module in model.named_modules():
                # Need to account for scenario where model is parallelized with DataParallel, which wraps the original
                # module with a wrapper module called 'module' :)
                name_to_match = module_full_name.replace('module.', '', 1)
                qbits = self.default_qbits
                override_entry = self.overrides.get(name_to_match, OrderedDict())
                if regex_overrides:
                    m_overrides = regex_overrides.match(name_to_match)
                    if m_overrides:
                        group_idx = 0
                        groups = m_overrides.groups()
                        while groups[group_idx] is None:
                            group_idx += 1
                        override_entry = copy.deepcopy(override_entry or self.overrides[patterns[group_idx]])
                        qbits = override_entry.pop('bits', self.default_qbits)

                self._add_qbits_entry(module_full_name, type(module), qbits)
                self._add_override_entry(module_full_name, override_entry)

            # Mapping from module type to function generating a replacement module suited for quantization
            # To be populated by child classes
            # Unspecified layer types return None by default.
            self.replacement_factory = defaultdict(lambda: None)
            # Pointer to parameters quantization function, triggered during training process
            # To be populated by child classes
            self.param_quantization_fn = None  # parameter quantization function

            self.train_with_fp_copy = train_with_fp_copy
            self.params_to_quantize = []

            # A dictionary of replaced modules and their respective names.
            self.modules_processed = OrderedDict()  # modules that have already been processed

        def _add_qbits_entry(self, module_name, module_type, qbits):
            if module_type not in [nn.Conv2d, nn.Conv3d, nn.Linear, nn.Embedding]:
                # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't
                # support quantization of batch norm scale parameters)
                qbits = QBits(acts=qbits.acts, wts=None, bias=None)
            self.module_qbits_map[module_name] = qbits

        def _add_override_entry(self, module_name, entry):
            self.module_overrides_map[module_name] = entry

        def prepare_model(self, dummy_input=None):
            """
            Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
            and overrides configuration provided to __init__(), and according to the replacement_factory as
            defined by the Quantizer sub-class being used.

            Note:
                If multiple sub-modules within the model actually reference the same module, then that module
                is replaced only once, according to the configuration (bit-width and/or overrides) of the
                first encountered reference.
                Toy Example - say a module is constructed using this bit of code:

                    shared_relu = nn.ReLU
                    self.relu1 = shared_relu
                    self.relu2 = shared_relu

                When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
                Let's call it `new_relu1'. When 'self.relu2' will be encountered, it'll simply be replaced
                with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
                will be ignored. A warning message will be shown.
            """
            msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))

            self.model.quantizer_metadata["dummy_input"] = dummy_input
            if dummy_input is not None:
                summary_graph = distiller.SummaryGraph(self.model, dummy_input)
                # Obtain the adjacency_map
                self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)

            self._pre_prepare_model(dummy_input)  # preprocessing: BN folding, optimizing modules with activations, etc.

            self._pre_process_container(self.model)  # main work: replace modules with their quantized wrappers
            for module_name, module in self.model.named_modules():
                qbits = self.module_qbits_map[module_name]
                curr_parameters = dict(module.named_parameters())
                for param_name, param in curr_parameters.items():
                    n_bits = qbits.bias if param_name.endswith('bias') else qbits.wts
                    if n_bits is None:
                        continue
                    fp_attr_name = param_name
                    if self.train_with_fp_copy:
                        hack_float_backup_parameter(module, param_name, n_bits)  # back up the float parameter
                        fp_attr_name = FP_BKP_PREFIX + param_name
                    # Record the parameter to quantize: owning module, whether it has an fp copy, bit setting
                    self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))

                    param_full_name = '.'.join([module_name, param_name])
                    msglogger.info(
                        "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))

            # If an optimizer was passed, assume we need to update it
            # (for example, when new parameters were added)
            if self.optimizer:
                optimizer_type = type(self.optimizer)
                new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
                self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})

            self._post_prepare_model()  # postprocessing

            msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))

        def _pre_prepare_model(self, dummy_input):
            pass

        def _pre_process_container(self, container, prefix=''):
            def replace_msg(module_name, modules=None):
                msglogger.debug('Module ' + module_name)
                if modules:
                    msglogger.debug('\tReplacing: {}.{}'.format(modules[0].__module__, modules[0].__class__.__name__))
                    msglogger.debug('\tWith: {}.{}'.format(modules[1].__module__, modules[1].__class__.__name__))
                else:
                    msglogger.debug('\tSkipping')

            # Iterate through model, insert quantization functions as appropriate
            # (walk every module in the model and replace it with its quantized wrapper)
            for name, module in container.named_children():
                full_name = prefix + name
                if module in self.modules_processed:
                    previous_name, previous_wrapper = self.modules_processed[module]
                    warnings.warn("Module '{0}' references to same module as '{1}'."
                                  ' Replacing with reference the same wrapper.'.format(full_name, previous_name),
                                  UserWarning)
                    if previous_wrapper:
                        replace_msg(full_name, (module, previous_wrapper))
                        setattr(container, name, previous_wrapper)
                    else:
                        replace_msg(full_name)
                    continue
                current_qbits = self.module_qbits_map[full_name]
                if current_qbits.acts is None and current_qbits.wts is None:
                    if self.module_overrides_map[full_name]:
                        raise ValueError("Adding overrides while not quantizing is not allowed.")
                    # We indicate this module wasn't replaced by a wrapper (no replacement)
                    replace_msg(full_name)
                    self.modules_processed[module] = full_name, None
                else:
                    # We use a type hint comment to let IDEs know replace_fn is a function
                    # Look up the wrapper factory (replace_fn) for the module to be quantized
                    replace_fn = self.replacement_factory[type(module)]  # type: Optional[Callable]
                    # If the replacement function wasn't specified - continue without replacing this module.
                    if replace_fn is not None:
                        valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name],
                                                                               replace_fn)
                        if invalid_kwargs:
                            raise TypeError("""Quantizer of type %s doesn't accept \"%s\" as override arguments for %s.
                                            Allowed kwargs: %s"""
                                            % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
                        # Replace the module to be quantized with its wrapper module
                        new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
                        replace_msg(full_name, (module, new_module))
                        # Add to history of prepared submodules
                        self.modules_processed[module] = full_name, new_module
                        setattr(container, name, new_module)

                        # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
                        if not distiller.has_children(module) and distiller.has_children(new_module):
                            for sub_module_name, sub_module in new_module.named_modules():
                                self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module),
                                                      current_qbits)
                            self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)

                if distiller.has_children(module):
                    # For container we call recursively
                    self._pre_process_container(module, full_name + '.')

        def _get_updated_optimizer_params_groups(self):
            """
            Returns a list of model parameter groups and optimizer hyper-parameter overrides,
            as expected by the __init__ function of torch.optim.Optimizer.
            This is called after all model changes were made in prepare_model, in case an Optimizer instance was
            passed to __init__.

            Subclasses which add parameters to the model should override as needed.

            :return: List of parameter groups
            """
            # Default implementation - just return all model parameters as one group
            return [{'params': self.model.parameters()}]

        def _post_prepare_model(self):
            pass

        def quantize_params(self):
            """
            Quantize all parameters using self.param_quantization_fn (with the defined number of bits for each parameter)
            """
            for ptq in self.params_to_quantize:
                q_param = self.param_quantization_fn(getattr(ptq.module, ptq.fp_attr_name), ptq)
                if self.train_with_fp_copy:
                    setattr(ptq.module, ptq.q_attr_name, q_param)
                else:
                    getattr(ptq.module, ptq.q_attr_name).data = q_param.data
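The way `__init__` resolves overrides (one capture group per pattern, first match wins) can be tried in isolation. The sketch below reuses the same regex construction as the listing above; `resolve_override` is a hypothetical helper name, and the example assumes the patterns contain no capture groups of their own, since those would shift the group indices. Patterns are regular expressions, so the broad pattern is written `conv.*` here.

```python
import re
from collections import OrderedDict


def resolve_override(name, overrides):
    """Return the entry of the first pattern that fully matches `name`, else None."""
    patterns = list(overrides.keys())
    if not patterns:
        return None
    # One capture group per pattern; the index of the non-empty group
    # identifies which pattern matched.
    regex = re.compile('|'.join('(^{0}$)'.format(p) for p in patterns))
    match = regex.match(name)
    if match is None:
        return None
    group_idx = next(i for i, g in enumerate(match.groups()) if g is not None)
    return overrides[patterns[group_idx]]


overrides = OrderedDict([
    ('conv1', {'bits_weights': 4}),   # specific pattern first
    ('conv.*', {'bits_weights': 8}),  # broad pattern afterwards
])

conv1_entry = resolve_override('conv1', overrides)
conv2_entry = resolve_override('conv2', overrides)
fc_entry = resolve_override('fc', overrides)

# A DataParallel-wrapped name is matched after stripping the leading 'module.'
wrapped_entry = resolve_override('module.conv1'.replace('module.', '', 1), overrides)
```

If the two patterns were listed in the opposite order, `conv1` would be captured by the broad pattern first, which is why the more specific patterns have to come earlier in the OrderedDict.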
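Summary
This article introduced Distiller and part of the implementation of its quantization features, mainly a brief walkthrough of the Quantizer base class. All the concrete quantizers covered later inherit from this base class.
Because the code is long, the concrete quantizer implementations are left to a follow-up article (Intel Distiller Toolkit: Quantization Implementation, Part 2).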