Yuantao Feng committed
Commit · 55654c5
Parent(s): 2074f99

Add tools for quantization and quantized models (#36)
* add scripts for quantization
* update path to pp-resnet50
* add quantized models
* rename dict to models
* add requirements and readme
* fix typos
- tools/quantize/README.md +38 -0
- tools/quantize/quantize.py +116 -0
- tools/quantize/requirements.txt +3 -0
- tools/quantize/transform.py +32 -0
tools/quantize/README.md
ADDED
@@ -0,0 +1,38 @@
# Quantization with ONNXRUNTIME

ONNXRUNTIME is used for quantization in the Zoo.

Install dependencies before trying quantization:

```shell
pip install -r requirements.txt
```

## Usage

Quantize all models in the Zoo:

```shell
python quantize.py
```

Quantize one of the models in the Zoo:

```shell
# python quantize.py <key_in_models>
python quantize.py yunet
```

Customize quantization configs by adding your model to the `models` dict in quantize.py:

```python
models = dict(
    # ...
    model1=Quantize(model_path='/path/to/model1.onnx',
                    calibration_image_dir='/path/to/images',
                    transforms=Compose([...]),  # available transforms are listed in transform.py
                    per_channel=False,          # set False to quantize in per-tensor style
                    act_type='int8',            # available types: 'int8', 'uint8'
                    wt_type='int8')             # available types: 'int8', 'uint8'
)
```

Then quantize the added model:

```shell
python quantize.py model1
```
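To sanity-check the result, you can run the quantized model directly with onnxruntime. A minimal sketch: the filename follows the `*-act_{}-wt_{}-quantized.onnx` pattern used by quantize.py, while the path and input shape below are placeholders you must adapt to your model:

```python
import numpy as np
import onnxruntime

# placeholder path; use the file written by quantize.py
sess = onnxruntime.InferenceSession('/path/to/model1-act_int8-wt_int8-quantized.onnx',
                                    providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
# placeholder NCHW blob; replace with your model's real input shape
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = sess.run(None, {input_name: dummy})
print([o.shape for o in outputs])
```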
tools/quantize/quantize.py
ADDED
@@ -0,0 +1,116 @@
```python
# This file is part of OpenCV Zoo project.
# It is subject to the license terms in the LICENSE file found in the same directory.
#
# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
# Third party copyrights are property of their respective owners.

import os
import sys
import numpy as np
import cv2 as cv

import onnx
from onnx import version_converter
import onnxruntime
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType

from transform import Compose, Resize, ColorConvert

class DataReader(CalibrationDataReader):
    '''Feeds preprocessed calibration images to onnxruntime's static quantizer.'''
    def __init__(self, model_path, image_dir, transforms):
        model = onnx.load(model_path)
        self.input_name = model.graph.input[0].name
        self.transforms = transforms
        self.data = self.get_calibration_data(image_dir)
        self.enum_data_dicts = iter([{self.input_name: x} for x in self.data])

    def get_next(self):
        # returning None tells the quantizer that calibration data is exhausted
        return next(self.enum_data_dicts, None)

    def get_calibration_data(self, image_dir):
        blobs = []
        for image_name in os.listdir(image_dir):
            if not image_name.endswith('jpg'):
                continue
            img = cv.imread(os.path.join(image_dir, image_name))
            img = self.transforms(img)
            blob = cv.dnn.blobFromImage(img)
            blobs.append(blob)
        return blobs

class Quantize:
    def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8'):
        self.type_dict = {"uint8": QuantType.QUInt8, "int8": QuantType.QInt8}

        self.model_path = model_path
        self.calibration_image_dir = calibration_image_dir
        self.transforms = transforms
        self.per_channel = per_channel
        self.act_type = act_type
        self.wt_type = wt_type

        # data reader
        self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms)

    def check_opset(self, convert=True):
        model = onnx.load(self.model_path)
        if model.opset_import[0].version != 11:
            print('\tmodel opset version: {}. Converting to opset 11'.format(model.opset_import[0].version))
            # convert opset version to 11
            model_opset11 = version_converter.convert_version(model, 11)
            # save converted model
            output_name = '{}-opset11.onnx'.format(self.model_path[:-5])
            onnx.save_model(model_opset11, output_name)
            # update model_path for quantization
            self.model_path = output_name

    def run(self):
        print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
        self.check_opset()
        output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type)
        quantize_static(self.model_path, output_name, self.dr,
                        per_channel=self.per_channel,
                        weight_type=self.type_dict[self.wt_type],
                        activation_type=self.type_dict[self.act_type])
        # clean up the intermediate files left behind by quantize_static
        os.remove('augmented_model.onnx')
        os.remove('{}-opt.onnx'.format(self.model_path[:-5]))
        print('\tQuantized model saved to {}'.format(output_name))


models = dict(
    yunet=Quantize(model_path='../../models/face_detection_yunet/face_detection_yunet_2021dec.onnx',
                   calibration_image_dir='../../benchmark/data/face_detection'),
    sface=Quantize(model_path='../../models/face_recognition_sface/face_recognition_sface_2021dec.onnx',
                   calibration_image_dir='../../benchmark/data/face_recognition',
                   transforms=Compose([Resize(size=(112, 112))])),
    pphumanseg=Quantize(model_path='../../models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2021oct.onnx',
                        calibration_image_dir='../../benchmark/data/human_segmentation',
                        transforms=Compose([Resize(size=(192, 192))])),
    ppresnet50=Quantize(model_path='../../models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx',
                        calibration_image_dir='../../benchmark/data/image_classification',
                        transforms=Compose([Resize(size=(224, 224))])),
    # TBD: DaSiamRPN
    youtureid=Quantize(model_path='../../models/person_reid_youtureid/person_reid_youtu_2021nov.onnx',
                       calibration_image_dir='../../benchmark/data/person_reid',
                       transforms=Compose([Resize(size=(128, 256))])),
    # TBD: DB-EN & DB-CN
    crnn_en=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_EN_2021sep.onnx',
                     calibration_image_dir='../../benchmark/data/text',
                     transforms=Compose([Resize(size=(100, 32)), ColorConvert(ctype=cv.COLOR_BGR2GRAY)])),
    crnn_cn=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_CN_2021nov.onnx',
                     calibration_image_dir='../../benchmark/data/text',
                     transforms=Compose([Resize(size=(100, 32))]))
)

if __name__ == '__main__':
    selected_models = []
    for i in range(1, len(sys.argv)):
        selected_models.append(sys.argv[i])
    if not selected_models:
        selected_models = list(models.keys())
    print('Models to be quantized: {}'.format(str(selected_models)))

    for selected_model_name in selected_models:
        q = models[selected_model_name]
        q.run()
```
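For reference, a `Quantize` instance can also be driven directly from Python rather than via the command line; a minimal sketch with placeholder paths (note that constructing `Quantize` eagerly builds a `DataReader`, so the calibration directory must exist and contain `.jpg` images):

```python
# Note: importing quantize also evaluates its module-level `models` dict,
# so run this from tools/quantize inside the repo where those paths resolve.
from quantize import Quantize
from transform import Compose, Resize

# placeholder paths; point these at a real model and calibration set
q = Quantize(model_path='/path/to/model.onnx',
             calibration_image_dir='/path/to/images',
             transforms=Compose([Resize(size=(224, 224))]))
q.run()
```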
tools/quantize/requirements.txt
ADDED
@@ -0,0 +1,3 @@
```
opencv-python>=4.5.4.58
onnx
onnxruntime
```
tools/quantize/transform.py
ADDED
@@ -0,0 +1,32 @@
```python
# This file is part of OpenCV Zoo project.
# It is subject to the license terms in the LICENSE file found in the same directory.
#
# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
# Third party copyrights are property of their respective owners.

import numpy as np
import cv2 as cv

class Compose:
    '''Chains transforms; each transform is a callable taking and returning an image.'''
    def __init__(self, transforms=[]):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

class Resize:
    def __init__(self, size, interpolation=cv.INTER_LINEAR):
        self.size = size  # (width, height)
        self.interpolation = interpolation

    def __call__(self, img):
        return cv.resize(img, self.size, interpolation=self.interpolation)

class ColorConvert:
    def __init__(self, ctype):
        self.ctype = ctype

    def __call__(self, img):
        return cv.cvtColor(img, self.ctype)
```
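Any object with a `__call__(img)` method composes with the classes above; a hypothetical `CenterCrop` (not part of this commit) illustrates the pattern:

```python
class CenterCrop:
    '''Hypothetical example transform: crop the central (width, height) region.'''
    def __init__(self, size):
        self.size = size  # (width, height), matching the convention of Resize

    def __call__(self, img):
        h, w = img.shape[:2]
        cw, ch = self.size
        x = max((w - cw) // 2, 0)
        y = max((h - ch) // 2, 0)
        return img[y:y + ch, x:x + cw]

# usage: Compose([CenterCrop(size=(112, 112)), ColorConvert(ctype=cv.COLOR_BGR2GRAY)])
```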