Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
4a69f28
1
Parent(s):
8d5f45b
Fix MPI library issue
Browse files- app.py +5 -1
- requirements.txt +3 -0
app.py
CHANGED
|
@@ -31,7 +31,7 @@ from train_settings.dvd.eval_utils import extract_raw_features_single,extract_ra
|
|
| 31 |
from datasets.utils.warping import register_model2
|
| 32 |
|
| 33 |
import gradio as gr
|
| 34 |
-
|
| 35 |
|
| 36 |
|
| 37 |
|
|
@@ -515,6 +515,7 @@ model, diffusion = create_model_and_diffusion(
|
|
| 515 |
setattr(diffusion, "settings", settings)
|
| 516 |
|
| 517 |
pretrained_dewarp_model = GeoTr_Seg_Inf()
|
|
|
|
| 518 |
reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
|
| 519 |
# reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
|
| 520 |
pretrained_dewarp_model.to(dist_util.dev())
|
|
@@ -523,16 +524,19 @@ pretrained_dewarp_model.eval()
|
|
| 523 |
if settings.env.use_line_mask:
|
| 524 |
pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
|
| 525 |
pretrained_seg_model = Seg()
|
|
|
|
| 526 |
line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
|
| 527 |
pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
|
| 528 |
pretrained_line_seg_model.to(dist_util.dev())
|
| 529 |
pretrained_line_seg_model.eval()
|
| 530 |
|
|
|
|
| 531 |
seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
|
| 532 |
pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
|
| 533 |
pretrained_seg_model.to(dist_util.dev())
|
| 534 |
pretrained_seg_model.eval()
|
| 535 |
|
|
|
|
| 536 |
model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
|
| 537 |
logger.log(f"Model loaded with {settings.env.model_path}")
|
| 538 |
|
|
|
|
| 31 |
from datasets.utils.warping import register_model2
|
| 32 |
|
| 33 |
import gradio as gr
|
| 34 |
+
from huggingface_hub import hf_hub_download
|
| 35 |
|
| 36 |
|
| 37 |
|
|
|
|
| 515 |
setattr(diffusion, "settings", settings)
|
| 516 |
|
| 517 |
pretrained_dewarp_model = GeoTr_Seg_Inf()
|
| 518 |
+
settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/dvd", filename="seg.pth")
|
| 519 |
reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
|
| 520 |
# reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
|
| 521 |
pretrained_dewarp_model.to(dist_util.dev())
|
|
|
|
| 524 |
if settings.env.use_line_mask:
|
| 525 |
pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
|
| 526 |
pretrained_seg_model = Seg()
|
| 527 |
+
settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/dvd", filename="line_model2.pth")
|
| 528 |
line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
|
| 529 |
pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
|
| 530 |
pretrained_line_seg_model.to(dist_util.dev())
|
| 531 |
pretrained_line_seg_model.eval()
|
| 532 |
|
| 533 |
+
settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/dvd", filename="seg_model.pth")
|
| 534 |
seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
|
| 535 |
pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
|
| 536 |
pretrained_seg_model.to(dist_util.dev())
|
| 537 |
pretrained_seg_model.eval()
|
| 538 |
|
| 539 |
+
settings.env.model_path = hf_hub_download(repo_id="hanquansanren/dvd", filename="model1852000.pt")
|
| 540 |
model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
|
| 541 |
logger.log(f"Model loaded with {settings.env.model_path}")
|
| 542 |
|
requirements.txt
CHANGED
|
@@ -3,7 +3,10 @@
|
|
| 3 |
# pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu118
|
| 4 |
# pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1/index.html
|
| 5 |
# gxx_linux-64
|
|
|
|
| 6 |
mpi4py
|
|
|
|
|
|
|
| 7 |
opencv-python
|
| 8 |
jpeg4py
|
| 9 |
packaging
|
|
|
|
| 3 |
# pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu118
|
| 4 |
# pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1/index.html
|
| 5 |
# gxx_linux-64
|
| 6 |
+
mmcv==2.2.0
|
| 7 |
mpi4py
|
| 8 |
+
matplotlib
|
| 9 |
+
huggingface_hub
|
| 10 |
opencv-python
|
| 11 |
jpeg4py
|
| 12 |
packaging
|