# dis_onnx / export_onnx.py
# leonelhs's picture
# change to onnx
# 12b5ec5
# raw
# history blame contribute delete
# 905 Bytes
import torch
from huggingface_hub import hf_hub_download
from models.isnet import ISNetDIS
# Hugging Face Hub repository that hosts the pretrained checkpoint.
REPO_ID = "leonelhs/removators"

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = ISNetDIS()

# Fetch the 'isnet.pth' checkpoint from the Hub repository; the returned
# value is the path handed to torch.load below.
model_path = hf_hub_download(repo_id=REPO_ID, filename='isnet.pth')

# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the checkpoint is not trusted. Consider passing
# weights_only=True once it is confirmed that the file holds a plain
# state dict.
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
net.eval()  # evaluation mode: the exported graph must not use training behavior

# Example input used to trace the model during export. It has to live on the
# same device as the model weights; a CPU tensor fed to a CUDA model makes the
# export fail with a device-mismatch error.
# Presumably laid out as (batch, channels, height, width) -- confirm against
# ISNetDIS.
dummy_input = torch.ones(1, 3, 1024, 1024, device=device)
# Serialize the network to an ONNX file by tracing it with the example input.
onnx_path = "linear_model.onnx"

# Axis 0 of both named tensors is declared dynamic so the exported graph
# accepts a variable batch size.
# NOTE(review): only one output name is supplied; if the model returns several
# tensors, confirm how the remaining outputs are named and shaped.
variable_axes = {
    "input": {0: "batch_size"},
    "output": {0: "batch_size"},
}

torch.onnx.export(
    net,
    dummy_input,
    onnx_path,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=variable_axes,
    opset_version=17,  # ONNX operator-set version to target
)

print("Model exported to linear_model.onnx")