import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
import os
import logging
from typing import Dict, Any
import subprocess
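
# Export a Hugging Face extractive question-answering model to ONNX, verify the
# exported graph, emit dynamically quantized variants, and save the tokenizer
# alongside the models.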
					
						
class ONNXModelConverter:
    def __init__(self, model_name: str, output_dir: str):
        self.model_name = model_name
        self.output_dir = output_dir
        self.setup_logging()

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"Loading tokenizer {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        self.logger.info(f"Loading model {model_name}...")
        self.model = AutoModelForQuestionAnswering.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32
        )
        self.model.eval()

    def setup_logging(self):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # Guard against attaching duplicate handlers if the converter is instantiated more than once.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
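
    # A single tokenized example is enough for torch.onnx.export to trace the graph;
    # the dynamic axes declared in export_to_onnx keep batch size and sequence
    # length flexible at inference time.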
					
						
    def prepare_dummy_inputs(self):
        dummy_input = self.tokenizer(
            "Hello, how are you?",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )
        return {
            'input_ids': dummy_input['input_ids'],
            'attention_mask': dummy_input['attention_mask'],
            'token_type_ids': dummy_input['token_type_ids']
        }
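
    # The nested ModelWrapper flattens the Hugging Face output object into a plain
    # (start_logits, end_logits) tuple so the exported graph exposes two named
    # tensor outputs.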
					
						
    def export_to_onnx(self):
        output_path = os.path.join(self.output_dir, "model.onnx")
        inputs = self.prepare_dummy_inputs()

        dynamic_axes = {
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'token_type_ids': {0: 'batch_size', 1: 'sequence_length'},
            'start_logits': {0: 'batch_size', 1: 'sequence_length'},
            'end_logits': {0: 'batch_size', 1: 'sequence_length'},
        }

        class ModelWrapper(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask, token_type_ids):
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
                return outputs.start_logits, outputs.end_logits

        wrapped_model = ModelWrapper(self.model)

        try:
            torch.onnx.export(
                wrapped_model,
                (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids']),
                output_path,
                export_params=True,
                opset_version=14,
                do_constant_folding=True,
                input_names=['input_ids', 'attention_mask', 'token_type_ids'],
                output_names=['start_logits', 'end_logits'],
                dynamic_axes=dynamic_axes,
                verbose=False
            )
            self.logger.info(f"Model exported to {output_path}")
            return output_path
        except Exception as e:
            self.logger.error(f"ONNX export failed: {str(e)}")
            raise
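
    # onnx.checker validates the graph structure and operator schemas only; it does
    # not execute the model.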
					
						
    def verify_model(self, model_path: str):
        try:
            onnx_model = onnx.load(model_path)
            onnx.checker.check_model(onnx_model)
            self.logger.info("ONNX model verification successful")
            return True
        except Exception as e:
            self.logger.error(f"Model verification failed: {str(e)}")
            return False
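
    # Optional pre-quantization cleanup (shape inference and graph optimization) via
    # onnxruntime's preprocessing tool. Note that convert() does not call this step.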
					
						
    def preprocess_model(self, model_path: str) -> str:
        preprocessed_path = os.path.join(self.output_dir, "model-infer.onnx")
        try:
            command = [
                "python", "-m", "onnxruntime.quantization.preprocess",
                "--input", model_path,
                "--output", preprocessed_path
            ]
            # check=True raises CalledProcessError on a non-zero exit code.
            subprocess.run(command, check=True, capture_output=True, text=True)
            self.logger.info(f"Model preprocessing successful. Output saved to {preprocessed_path}")
            return preprocessed_path
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Preprocessing failed: {e.stderr}")
            raise
        except Exception as e:
            self.logger.error(f"Preprocessing failed: {str(e)}")
            raise
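
    # Dynamic (weight-only) quantization across several weight types. int8/uint8 have
    # the broadest support in quantize_dynamic; 4-bit and 16-bit weight types may be
    # rejected depending on the installed onnxruntime version.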
					
						
    def quantize_model(self, model_path: str):
        weight_types = {
            'int4': QuantType.QInt4,
            'int8': QuantType.QInt8,
            'uint4': QuantType.QUInt4,
            'uint8': QuantType.QUInt8,
            'uint16': QuantType.QUInt16,
            'int16': QuantType.QInt16,
        }
        all_quantized_paths = []
        for name, weight_type in weight_types.items():
            quantized_path = os.path.join(self.output_dir, f"model_{name}.onnx")

            try:
                quantize_dynamic(
                    model_path,
                    quantized_path,
                    weight_type=weight_type
                )
                self.logger.info(f"Model quantized ({name}) and saved to {quantized_path}")
                all_quantized_paths.append(quantized_path)
            except Exception as e:
                self.logger.error(f"Quantization ({name}) failed: {str(e)}")
                raise

        return all_quantized_paths
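
    # Full pipeline: export to ONNX, verify the graph, quantize, and save the
    # tokenizer next to the models so all artifacts can be loaded together later.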
					
						
    def convert(self) -> Dict[str, Any]:
        try:
            onnx_path = self.export_to_onnx()

            if self.verify_model(onnx_path):
                quantized_paths = self.quantize_model(onnx_path)

                tokenizer_path = os.path.join(self.output_dir, "tokenizer")
                self.tokenizer.save_pretrained(tokenizer_path)
                self.logger.info(f"Tokenizer saved to {tokenizer_path}")

                return {
                    'onnx_model': onnx_path,
                    'quantized_models': quantized_paths,
                    'tokenizer': tokenizer_path
                }
            else:
                raise Exception("Model verification failed")
        except Exception as e:
            self.logger.error(f"Conversion process failed: {str(e)}")
            raise
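
# Example entry point: convert Intel/dynamic_tinybert and write every artifact
# under the local "onnx" directory.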
					
						
if __name__ == "__main__":
    MODEL_NAME = "Intel/dynamic_tinybert"
    OUTPUT_DIR = "onnx"

    try:
        converter = ONNXModelConverter(MODEL_NAME, OUTPUT_DIR)
        results = converter.convert()

        print("\nConversion completed successfully!")
        print(f"ONNX model path: {results['onnx_model']}")
        print(f"Quantized model paths: {results['quantized_models']}")
        print(f"Tokenizer path: {results['tokenizer']}")
    except Exception as e:
        print(f"Conversion failed: {str(e)}")