Spaces:
Running
Running
| import torch | |
| import torch.onnx | |
| import onnx | |
| from VitsModelSplit.vits_model_only_d import Vits_models_only_decoder | |
| from VitsModelSplit.vits_model import VitsModel | |
| import gradio as gr | |
| import os | |
| from transformers import AutoTokenizer | |
# Module-level tokenizer, loaded once at import time from the Hub repo below.
# OnnxModelConverter.convert_only_decoder uses it to build its tracing input.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="wasmdashai/vits-ar")
def create_file(file_path):
    """Return *file_path* if it names an existing regular file, else ``None``.

    Despite its name this function creates nothing: it only checks the path.
    It is the callback behind the "send" button, whose result is shown in a
    ``gr.File`` download component.

    Args:
        file_path (str): Path typed by the user, e.g. ``/tmp/<name>.onnx``.

    Returns:
        str | None: The unchanged path when it points to a regular file;
        ``None`` for empty input, a missing path, or a directory.
    """
    # SECURITY NOTE(review): the path comes straight from a public textbox, so
    # any file readable by the server process can be requested. Consider
    # restricting it to the ONNX output directory.
    if not file_path:
        # Empty textbox ("") or no value at all: nothing to serve.
        return None
    # isfile rather than exists: a directory path is not a downloadable file.
    if not os.path.isfile(file_path):
        return None
    return file_path
class OnnxModelConverter:
    """Export VITS text-to-speech models (or parts of them) to ONNX.

    Each ``convert_*`` method loads a pretrained model, traces it with a dummy
    input through ``torch.onnx.export`` and returns the path of the written
    file. ``starrt`` wires :meth:`convert` into a Gradio UI.
    """

    # Default directory for exported ``.onnx`` files (the location the exports
    # have always used); an instance may override it via ``__init__``.
    output_dir = "/tmp"

    def __init__(self, output_dir=None):
        """Create a converter.

        Args:
            output_dir (str, optional): Directory for exported ``.onnx`` files.
                ``None`` keeps the class default ``/tmp``.
        """
        # Not used by any method below; kept for backward compatibility.
        self.model = None
        if output_dir is not None:
            self.output_dir = output_dir

    def download_file(self, file_path):
        """Return *file_path* if it names an existing regular file, else ``None``.

        The result is what the Gradio ``gr.File`` output component receives.
        """
        # isfile rather than exists: a directory is not a downloadable file.
        return file_path if os.path.isfile(file_path) else None

    @staticmethod
    def _clean_token(token):
        """Normalize a Hub access token coming from a Gradio textbox.

        An empty textbox yields ``""``. ``from_pretrained`` treats any string as
        an explicit token, so blank input is mapped to ``None`` to fall back to
        the default credentials. Whitespace from copy/paste is stripped.
        """
        if isinstance(token, str):
            token = token.strip()
            return token or None
        return token

    def _onnx_path(self, onnx_filename):
        """Build the output path ``<output_dir>/<name>.onnx`` for a user-given name.

        Only the last path component of the name is kept, so input such as
        ``"../../etc/x"`` cannot write outside ``output_dir``. A trailing
        ``.onnx`` is dropped to avoid producing ``name.onnx.onnx``.

        Raises:
            ValueError: If no usable file name remains.
        """
        safe_name = os.path.basename((onnx_filename or "").strip())
        if safe_name.endswith(".onnx"):
            safe_name = safe_name[: -len(".onnx")]
        if not safe_name:
            raise ValueError("Please provide a file name for the ONNX output.")
        return os.path.join(self.output_dir, f"{safe_name}.onnx")

    def convert(self, model_name, token, onnx_filename, conversion_type):
        """
        Main function to handle different types of model conversions.
        It is also the callback of the "convert" button in the Gradio UI.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model (blank = default credentials).
            onnx_filename (str): Desired file name (without directory) for the ONNX output.
            conversion_type (str): Type of conversion ('decoder', 'only_decoder', or 'full_model').
        Returns:
            str: The path to the generated ONNX file.
        Raises:
            ValueError: If ``conversion_type`` is not one of the supported values.
        """
        converters = {
            "decoder": self.convert_decoder,
            "only_decoder": self.convert_only_decoder,
            "full_model": self.convert_full_model,
        }
        try:
            converter = converters[conversion_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values; both mean "not a valid choice".
            raise ValueError(
                "Invalid conversion type. Choose from 'decoder', 'only_decoder', or 'full_model'."
            ) from None
        return converter(model_name, token, onnx_filename)

    def convert_decoder(self, model_name, token, onnx_filename):
        """
        Export only ``model.decoder`` of a ``VitsModel`` to ONNX format.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired file name for the ONNX output.
        Returns:
            str: The path to the generated ONNX file.
        """
        model = VitsModel.from_pretrained(model_name, token=self._clean_token(token))
        onnx_file = self._onnx_path(onnx_filename)
        # Dummy decoder input of shape (batch=1, channels, frames=10). The
        # original export hard-coded 192 channels.
        # NOTE(review): assumes ``config.flow_size`` is the decoder's input width
        # (as in transformers' VitsConfig); 192 is used when it is absent.
        channels = getattr(getattr(model, "config", None), "flow_size", 192)
        example_input = torch.randn(1, channels, 10)
        torch.onnx.export(
            model.decoder,
            example_input,
            onnx_file,
            opset_version=11,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size", 2: "seq_len"},
                          "output": {0: "batch_size", 1: "sequence_length"}},
        )
        return self.download_file(onnx_file)

    def convert_only_decoder(self, model_name, token, onnx_filename):
        """
        Export the ``Vits_models_only_decoder`` variant of the model to ONNX format.

        The model is traced with token IDs produced by the module-level
        ``tokenizer`` for a short Arabic sample sentence.
        NOTE(review): despite the class name, the traced input is token IDs;
        confirm which part of VITS this variant actually wraps.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired file name for the ONNX output.
        Returns:
            str: The path to the generated ONNX file.
        """
        model = Vits_models_only_decoder.from_pretrained(model_name, token=self._clean_token(token))
        onnx_file = self._onnx_path(onnx_filename)
        # Arabic sample: "Peace be upon you, how are you?". It only serves as a
        # tracing example, but it must be real Arabic text the tokenizer knows.
        # NOTE(review): it is tokenized with the fixed module-level tokenizer,
        # not one matching ``model_name``, so its IDs must fit that model's vocab.
        inputs = tokenizer("السلام عليكم كيف الحال", return_tensors="pt")
        example_inputs = inputs.input_ids.to(torch.long)
        torch.onnx.export(
            model,
            example_inputs,
            onnx_file,
            input_names=["input"],
            output_names=["output"],
            # NOTE(review): input and output share the symbolic axis name
            # "sequence_length", which declares them equal in length; confirm
            # that holds for this model's output.
            dynamic_axes={"input": {0: "batch_size", 1: "sequence_length"},
                          "output": {0: "batch_size", 1: "sequence_length"}},
        )
        return self.download_file(onnx_file)

    def convert_full_model(self, model_name, token, onnx_filename):
        """
        Converts the full Vits model (including encoder and decoder) to ONNX format.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired file name for the ONNX output.
        Returns:
            str: The path to the generated ONNX file.
        """
        model = VitsModel.from_pretrained(model_name, token=self._clean_token(token))
        onnx_file = self._onnx_path(onnx_filename)
        # Random token IDs within the text encoder's embedding table: (batch=1, 100 tokens).
        vocab_size = model.text_encoder.embed_tokens.weight.size(0)
        example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
        torch.onnx.export(
            model,
            example_input,
            onnx_file,
            opset_version=11,
            input_names=["input"],
            output_names=["output"],
            # NOTE(review): only the output's batch axis is dynamic; if the first
            # output's length depends on the input, its axis 1 likely needs to be
            # dynamic too — confirm.
            dynamic_axes={"input": {0: "batch_size", 1: "sequence_length"}, "output": {0: "batch_size"}},
        )
        return self.download_file(onnx_file)

    def starrt(self):
        """Build the Gradio Blocks UI for the converter.

        The left column collects the model name, token, output name and
        conversion type. The right column holds the "convert" button with its
        ONNX download slot, plus a textbox/"send" button pair that serves an
        existing file through ``create_file``.

        Returns:
            gr.Blocks: The assembled (not yet launched) demo.
        """
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    text_n_model = gr.Textbox(label="name model")
                    text_n_token = gr.Textbox(label="token")
                    text_n_onxx = gr.Textbox(label="name model onxx")
                    choice = gr.Dropdown(choices=["decoder", "only_decoder", "full_model"], label="My Dropdown")
                with gr.Column():
                    btn = gr.Button("convert")
                    # Static display only; no event updates it.
                    label = gr.Label("return name model onxx")
                    converted_file = gr.File(label="Download ONNX File")
                    btn.click(self.convert, [text_n_model, text_n_token, text_n_onxx, choice], [converted_file])
                    btx = gr.Textbox("namefile")
                    download_button1 = gr.Button("send")
                    download_button = gr.File(label="Download ONNX File")
                    download_button1.click(create_file, [btx], [download_button])
        return demo
# Script entry: build the converter, assemble its Gradio UI and serve it
# with a public share link. Runs on import, as the Space expects.
c = OnnxModelConverter()
cc = c.starrt()
cc.launch(share=True)