# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom version for quantized training and evaluation functions.

The main difference between this and the third_party graph_rewriter_builder.py
is that this version uses experimental_create_training_graph which allows the
customization of freeze_bn_delay.
"""
import re

import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
def build(graph_rewriter_config,
          quant_overrides_config=None,
          is_training=True,
          is_export=False):
  """Returns a function that modifies default graph based on options.

  Args:
    graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
    quant_overrides_config: quant_overrides_pb2.QuantOverrides proto.
    is_training: whether in training or eval mode.
    is_export: whether exporting the graph.

  Returns:
    A no-argument callable that, when invoked, rewrites the current default
    graph in place by inserting fake-quantization ops.
  """

  def graph_rewrite_fn():
    """Function to quantize weights and activation of the default graph."""
    if (graph_rewriter_config.quantization.weight_bits != 8 or
        graph_rewriter_config.quantization.activation_bits != 8):
      raise ValueError('Only 8bit quantization is supported')

    graph = tf.get_default_graph()

    # Insert custom quant ops.
    if quant_overrides_config is not None:
      input_to_ops_map = input_to_ops.InputToOps(graph)
      for q in quant_overrides_config.quant_configs:
        # Fix: tf.Graph.get_operation_by_name raises KeyError for a missing
        # op (it never returns None), so the previous `if producer is None`
        # check was dead code and the ValueError below was unreachable.
        try:
          producer = graph.get_operation_by_name(q.op_name)
        except KeyError:
          raise ValueError('Op name does not exist in graph.')
        context = _get_context_from_op(producer)
        consumers = input_to_ops_map.ConsumerOperations(producer)
        if q.fixed_range:
          _insert_fixed_quant_op(
              context,
              q.quant_op_name,
              producer,
              consumers,
              init_min=q.min,
              init_max=q.max,
              # During eval the quant ops are always active.
              quant_delay=q.delay if is_training else 0)
        else:
          raise ValueError('Learned ranges are not yet supported.')

    # Quantize the graph by inserting quantize ops for weights and activations
    if is_training:
      contrib_quantize.experimental_create_training_graph(
          input_graph=graph,
          quant_delay=graph_rewriter_config.quantization.delay,
          # Freeze batch-norm statistics at the same step quantization starts.
          freeze_bn_delay=graph_rewriter_config.quantization.delay)
    else:
      contrib_quantize.experimental_create_eval_graph(
          input_graph=graph,
          # Exported graphs must have quantization active from step 0.
          quant_delay=graph_rewriter_config.quantization.delay
          if not is_export else 0)

    contrib_layers.summarize_collection('quant_vars')

  return graph_rewrite_fn
| def _get_context_from_op(op): | |
| """Gets the root context name from the op name.""" | |
| context_re = re.search(r'^(.*)/([^/]+)', op.name) | |
| if context_re: | |
| return context_re.group(1) | |
| return '' | |
def _insert_fixed_quant_op(context,
                           name,
                           producer,
                           consumers,
                           init_min=-6.0,
                           init_max=6.0,
                           quant_delay=None):
  """Adds a fake quant op with fixed ranges.

  Args:
    context: The parent scope of the op to be quantized.
    name: The name of the fake quant op.
    producer: The producer op to be quantized.
    consumers: The consumer ops to the producer op.
    init_min: The minimum range for the fake quant op.
    init_max: The maximum range for the fake quant op.
    quant_delay: Number of steps to wait before activating the fake quant op.

  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  scope = context + '/' + name if context else name
  producer_tensor = producer.outputs[0]
  quant_tensor = quant_ops.FixedQuantize(
      producer_tensor, init_min=init_min, init_max=init_max, scope=scope)

  # When a positive delay is requested, gate the quantized tensor on the
  # global quantization step: consumers see the raw producer output until
  # `quant_delay` steps have elapsed.
  out_tensor = quant_tensor
  if quant_delay is not None and quant_delay > 0:
    is_active = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=scope + '/activate_quant')
    out_tensor = control_flow_ops.cond(
        is_active,
        lambda: quant_tensor,
        lambda: producer_tensor,
        name=scope + '/delayed_quant')

  if consumers:
    rerouted_count = common.RerouteTensor(
        out_tensor, producer_tensor, can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # rerouted_count is greater than or equal to the length of the set
    # of consumers.
    if rerouted_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))