Spaces:
				
			
			
	
			
			
		No application file
		
	
	
	
			
			
	
	
	
	
		
		
		No application file
		
	Upload 2 files
Browse files- app.py +51 -0
 - requirements.txt +5 -0
 
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,51 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import gradio as gr
         
     | 
| 2 | 
         
            +
            import numpy as np
         
     | 
| 3 | 
         
            +
            import onnxruntime as ort
         
     | 
| 4 | 
         
            +
            from transformers import AutoTokenizer, AutoModelForCausalLM
         
     | 
| 5 | 
         
            +
            from huggingface_hub import hf_hub_download
         
     | 
| 6 | 
         
            +
            import torch
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            # Hub identifiers: the reference PyTorch checkpoint, the repository that
            # hosts the ONNX export, and the ONNX file name inside that repository.
            HF_MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
            HF_ONNX_REPO = "techAInewb/mistral-nemo-2407-fp32"
            ONNX_MODEL_FILE = "model.onnx"

            # One tokenizer is shared by both backends.
            tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)

            # PyTorch reference model in float32, switched to evaluation mode.
            # nn.Module.eval() returns the module itself, so it can be chained here.
            pt_model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                torch_dtype=torch.float32,
            ).eval()

            # Fetch the ONNX file from the Hub and open a CPU-only inference session.
            onnx_path = hf_hub_download(
                repo_id=HF_ONNX_REPO,
                filename=ONNX_MODEL_FILE,
            )
            onnx_session = ort.InferenceSession(
                onnx_path,
                providers=["CPUExecutionProvider"],
            )
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            def compare_outputs(prompt):
                """Compare the top next-token predictions of the PyTorch and ONNX models.

                The prompt is tokenized once and the same token ids and attention
                mask are fed to both backends, so any difference in the reported
                tokens comes from the models rather than from the inputs.

                Args:
                    prompt: Text entered by the user.

                Returns:
                    A pair of strings. The first lists the five highest-scoring
                    next tokens according to the PyTorch model, the second the five
                    according to the ONNX model. If the prompt tokenizes to zero
                    tokens, both strings carry the same explanatory message.
                """
                top_k = 5

                torch_inputs = tokenizer(prompt, return_tensors="pt")
                input_ids = torch_inputs["input_ids"]
                attention_mask = torch_inputs["attention_mask"]

                # With no tokens there is no "last position" to read logits from;
                # report that instead of failing on the [0, -1] index below.
                if input_ids.shape[-1] == 0:
                    message = "Prompt produced no tokens; please enter some text."
                    return message, message

                # Run PyTorch: logits at the last position of the only sequence.
                with torch.no_grad():
                    pt_logits = pt_model(**torch_inputs).logits
                pt_top = torch.topk(pt_logits[0, -1], top_k).indices.tolist()

                # Run ONNX on NumPy views of the very same tensors.
                ort_inputs = {
                    "input_ids": input_ids.cpu().numpy(),
                    "attention_mask": attention_mask.cpu().numpy(),
                }
                ort_logits = onnx_session.run(None, ort_inputs)[0]
                # argsort is ascending, so reverse it before taking the top entries.
                ort_top = np.argsort(ort_logits[0, -1])[::-1][:top_k].tolist()

                pt_tokens = tokenizer.convert_ids_to_tokens(pt_top)
                ort_tokens = tokenizer.convert_ids_to_tokens(ort_top)

                return f"PyTorch Top Tokens: {pt_tokens}", f"ONNX Top Tokens: {ort_tokens}"
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
            # Single prompt box in, two text panels out (PyTorch result, ONNX result).
            prompt_box = gr.Textbox(lines=2, placeholder="Enter a prompt...")

            iface = gr.Interface(
                fn=compare_outputs,
                inputs=prompt_box,
                outputs=["text", "text"],
                title="ONNX vs PyTorch Model Comparison",
                description="Run both PyTorch and ONNX inference on a prompt and compare top predicted tokens.",
            )

            # Start the web app.
            iface.launch()
         
     | 
    	
        requirements.txt
    ADDED
    
    | 
         @@ -0,0 +1,5 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            gradio
         
     | 
| 2 | 
         
            +
            transformers
         
     | 
| 3 | 
         
            +
            torch
         
     | 
| 4 | 
         
            +
            onnxruntime
         
     | 
| 5 | 
         
            +
            huggingface_hub
         
     |