|
|
import torch |
|
|
from huggingface_hub import hf_hub_download |
|
|
from models.isnet import ISNetDIS |
|
|
|
|
|
# Hugging Face Hub repository that hosts the pretrained ISNet checkpoint.
REPO_ID = "leonelhs/removators"

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the network and fill it with the checkpoint fetched from the Hub
# (hf_hub_download returns a local filesystem path to the downloaded file).
net = ISNetDIS()
model_path = hf_hub_download(repo_id=REPO_ID, filename="isnet.pth")
# NOTE(review): torch.load unpickles the checkpoint, and this file comes from
# a remote repository. Consider passing weights_only=True so that only
# tensors and plain containers can be deserialized.
net.load_state_dict(torch.load(model_path, map_location=device))

# Move the parameters onto the target device and switch the module to
# evaluation mode before it is traced for export.
net.to(device).eval()
|
|
|
|
|
# Example input used to trace the network during export (presumably NCHW:
# one 3-channel 1024x1024 image; TODO confirm against ISNetDIS). It must be
# created on `device`: the model's parameters were moved there, and a CPU
# tensor would raise a device-mismatch error whenever CUDA is available.
dummy_input = torch.ones(1, 3, 1024, 1024, device=device)

# Destination of the exported graph, shared by the export call and the
# status message so the two cannot drift apart.
# NOTE(review): the name says "linear" but the model is ISNetDIS; it is kept
# unchanged in case downstream tooling expects this exact filename.
ONNX_PATH = "linear_model.onnx"

torch.onnx.export(
    net,
    dummy_input,
    ONNX_PATH,
    input_names=["input"],
    # NOTE(review): only one output name is given. If ISNetDIS returns
    # several tensors, the remaining graph outputs get auto-generated names
    # and no dynamic batch axis. Confirm that this is intended.
    output_names=["output"],
    # Only the batch dimension is dynamic. Every other dimension is baked in
    # from dummy_input's shape (spatial size fixed at 1024x1024).
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    opset_version=17,
)

print(f"Model exported to {ONNX_PATH}")
|
|
|