# OnnxAnalysis / app.py
import json
import os
from collections import defaultdict
from datetime import datetime
from tempfile import NamedTemporaryFile

import gradio as gr
import numpy as np
import onnx
import onnx.numpy_helper
from huggingface_hub import HfApi, login
def analyze_weight_data(tensor):
    """Analyze a single initializer tensor: shape, dtype, statistics, and samples."""
    try:
        np_array = onnx.numpy_helper.to_array(tensor)
        info = {
            'name': tensor.name,
            'shape': list(np_array.shape),
            # str(np_array.dtype) avoids the deprecated onnx.mapping module
            # (its TENSOR_TYPE_TO_NP_TYPE values are dtype instances, which
            # have no __name__ attribute).
            'data_type': str(np_array.dtype),
            'size': int(np_array.size),
            'bytes': int(np_array.nbytes),
            'statistics': {
                'min': float(np_array.min()),
                'max': float(np_array.max()),
                'mean': float(np_array.mean()),
                'std': float(np_array.std()),
                'non_zero': int(np.count_nonzero(np_array)),
                'zero_count': int(np_array.size - np.count_nonzero(np_array)),
                'unique_values': int(len(np.unique(np_array))),
            },
            'distribution': {
                'percentiles': {
                    '1%': float(np.percentile(np_array, 1)),
                    '25%': float(np.percentile(np_array, 25)),
                    '50%': float(np.percentile(np_array, 50)),
                    '75%': float(np.percentile(np_array, 75)),
                    '99%': float(np.percentile(np_array, 99)),
                },
                'sparsity': float(np.count_nonzero(np_array == 0) / np_array.size),
            },
        }
        # Keep the full tensor for small weights; otherwise store head/tail samples.
        if np_array.size <= 100:
            info['sample_values'] = np_array.tolist()
        else:
            flat = np_array.flatten()
            info['sample_values'] = {
                'first_10': flat[:10].tolist(),
                'last_10': flat[-10:].tolist()
            }
        return info
    except Exception as e:
        return {
            'name': tensor.name,
            'error': f"Analysis failed: {e}"
        }
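
# Illustrative sketch (not called by the app): for a small float32 initializer,
# analyze_weight_data returns a dict along these lines; the name "w" and the
# values shown are hypothetical.
#
#   tensor = onnx.numpy_helper.from_array(np.zeros((2, 2), dtype=np.float32), name="w")
#   info = analyze_weight_data(tensor)
#   # info['shape'] == [2, 2], info['data_type'] == 'float32',
#   # info['distribution']['sparsity'] == 1.0, and, since size <= 100,
#   # info['sample_values'] == [[0.0, 0.0], [0.0, 0.0]]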
def analyze_model_structure(graph):
    """Analyze the graph topology: nodes, tensor-to-consumer connections, and I/O."""
    structure = {
        'nodes': [],
        'connections': defaultdict(list),
        'input_nodes': [],
        'output_nodes': [],
    }
    for node in graph.node:
        node_info = {
            'name': node.name or f"node_{len(structure['nodes'])}",
            'op_type': node.op_type,
            'inputs': list(node.input),
            'outputs': list(node.output),
            # Attribute values are not decoded; the dict is kept as a placeholder.
            'attributes': {}
        }
        structure['nodes'].append(node_info)
        # Map each input tensor name to the nodes that consume it.
        for input_name in node.input:
            structure['connections'][input_name].append(node_info['name'])
    structure['input_nodes'] = [inp.name for inp in graph.input]
    structure['output_nodes'] = [out.name for out in graph.output]
    return structure
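
# Usage sketch: the 'connections' map answers "which nodes consume this tensor?".
# The tensor name below is hypothetical; real names come from the loaded graph.
#
#   structure = analyze_model_structure(model.graph)
#   consumers = structure['connections']['conv1_output']  # list of node names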
def analyze_onnx_model(model_path):
    """Load an ONNX model and produce a full analysis dictionary."""
    model = onnx.load(model_path)
    graph = model.graph
    analysis = {
        'model_info': {
            'ir_version': str(model.ir_version),
            'producer_name': model.producer_name,
            'producer_version': model.producer_version,
            'domain': model.domain,
            'model_version': str(model.model_version),
            'doc_string': model.doc_string,
        },
        'structure': analyze_model_structure(graph),
        'weights_analysis': {},
        'computation_info': {
            'total_params': 0,
            'total_memory': 0,
            'layer_stats': defaultdict(int)
        }
    }
    for tensor in graph.initializer:
        weight_info = analyze_weight_data(tensor)
        analysis['weights_analysis'][tensor.name] = weight_info
        if 'shape' in weight_info:
            # Cast to int: np.prod returns np.int64, which json.dump cannot serialize.
            analysis['computation_info']['total_params'] += int(np.prod(weight_info['shape']))
            analysis['computation_info']['total_memory'] += weight_info.get('bytes', 0)
    # Count how many nodes of each op type appear in the graph.
    for node in graph.node:
        analysis['computation_info']['layer_stats'][node.op_type] += 1
    return analysis
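
# Standalone usage sketch, outside the Gradio UI ("model.onnx" is a hypothetical path):
#
#   analysis = analyze_onnx_model("model.onnx")
#   print(f"{analysis['computation_info']['total_params']:,} parameters")
#   print(dict(analysis['computation_info']['layer_stats']))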
def format_analysis_text(analysis):
    """Format analysis results as a human-readable text report."""
    text = "=== MODEL INFORMATION ===\n"
    for key, value in analysis['model_info'].items():
        if value:
            text += f"{key}: {value}\n"
    text += "\n=== NETWORK STRUCTURE ===\n"
    text += f"Total layers: {len(analysis['structure']['nodes'])}\n"
    text += f"Input nodes: {', '.join(analysis['structure']['input_nodes'])}\n"
    text += f"Output nodes: {', '.join(analysis['structure']['output_nodes'])}\n"
    text += "\n=== LAYER STATISTICS ===\n"
    for op_type, count in analysis['computation_info']['layer_stats'].items():
        text += f"{op_type}: {count} layers\n"
    text += "\n=== WEIGHTS ANALYSIS ===\n"
    total_params = analysis['computation_info']['total_params']
    total_memory = analysis['computation_info']['total_memory']
    text += f"Total parameters: {total_params:,}\n"
    text += f"Total memory usage: {total_memory / 1024 / 1024:.2f} MB\n"
    return text
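
# The resulting report looks roughly like this (values are hypothetical):
#
#   === MODEL INFORMATION ===
#   ir_version: 8
#   producer_name: pytorch
#
#   === NETWORK STRUCTURE ===
#   Total layers: 42
#   ...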
def save_to_hf_dataset(analysis, model_name):
    """
    Save analysis results to a Hugging Face dataset, authenticating with the
    token from the HF_TOKEN environment variable.

    Args:
        analysis: Analysis results to save
        model_name: Name of the model being analyzed
    """
    # Get token from environment variable
    hf_token = os.getenv('HF_TOKEN')
    if not hf_token:
        raise ValueError("HF_TOKEN environment variable not found")
    # Login with token
    login(token=hf_token)
    # Initialize API
    api = HfApi()
    # Create temporary files
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    with NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as txt_file:
        txt_file.write(format_analysis_text(analysis))
        txt_path = txt_file.name
    with NamedTemporaryFile(mode='w', suffix='.json', delete=False) as json_file:
        json.dump(analysis, json_file, indent=2)
        json_path = json_file.name
    try:
        # Upload both reports to the dataset repo
        api.upload_file(
            path_or_fileobj=txt_path,
            path_in_repo=f"analysis_{model_name}_{timestamp}.txt",
            repo_id="Arrcttacsrks/OnnxAnalysisData",
            repo_type="dataset"
        )
        api.upload_file(
            path_or_fileobj=json_path,
            path_in_repo=f"analysis_{model_name}_{timestamp}.json",
            repo_id="Arrcttacsrks/OnnxAnalysisData",
            repo_type="dataset"
        )
    finally:
        # Clean up temporary files
        os.unlink(txt_path)
        os.unlink(json_path)
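
# Note: the upload requires an HF_TOKEN with write access to the
# Arrcttacsrks/OnnxAnalysisData dataset repo, e.g. set as a Space secret or
# exported in the shell before launching the app.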
def analyze_and_save(model_file):
    """Main function for the Gradio interface."""
    try:
        # gr.File may pass a filepath string or a tempfile-like object,
        # depending on the Gradio version; handle both.
        model_path = model_file if isinstance(model_file, str) else model_file.name
        # Analyze model
        analysis = analyze_onnx_model(model_path)
        # Format results
        text_output = format_analysis_text(analysis)
        # Save to dataset
        model_name = os.path.splitext(os.path.basename(model_path))[0]
        save_to_hf_dataset(analysis, model_name)
        # Create downloadable report files
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        txt_file = f"analysis_{model_name}_{timestamp}.txt"
        json_file = f"analysis_{model_name}_{timestamp}.json"
        with open(txt_file, 'w') as f:
            f.write(text_output)
        with open(json_file, 'w') as f:
            json.dump(analysis, f, indent=2)
        return text_output, txt_file, json_file
    except Exception as e:
        return f"Error: {e}", None, None
# Create Gradio interface
with gr.Blocks(title="ONNX Model Analyzer") as demo:
    gr.Markdown("# ONNX Model Analyzer")
    gr.Markdown("Upload an ONNX model to analyze its structure and parameters.")
    with gr.Row():
        input_file = gr.File(label="Upload ONNX Model")
    with gr.Row():
        analyze_btn = gr.Button("Analyze Model")
    with gr.Row():
        output_text = gr.Textbox(label="Analysis Results", lines=20)
    with gr.Row():
        txt_output = gr.File(label="Download TXT Report")
        json_output = gr.File(label="Download JSON Report")
    analyze_btn.click(
        fn=analyze_and_save,
        inputs=[input_file],
        outputs=[output_text, txt_output, json_output]
    )

# Launch app
if __name__ == "__main__":
    demo.launch()