hanquansanren commited on
Commit
4a69f28
·
1 Parent(s): 8d5f45b

Fix MPI library issue

Browse files
Files changed (2) hide show
  1. app.py +5 -1
  2. requirements.txt +3 -0
app.py CHANGED
@@ -31,7 +31,7 @@ from train_settings.dvd.eval_utils import extract_raw_features_single,extract_ra
31
  from datasets.utils.warping import register_model2
32
 
33
  import gradio as gr
34
-
35
 
36
 
37
 
@@ -515,6 +515,7 @@ model, diffusion = create_model_and_diffusion(
515
  setattr(diffusion, "settings", settings)
516
 
517
  pretrained_dewarp_model = GeoTr_Seg_Inf()
 
518
  reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
519
  # reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
520
  pretrained_dewarp_model.to(dist_util.dev())
@@ -523,16 +524,19 @@ pretrained_dewarp_model.eval()
523
  if settings.env.use_line_mask:
524
  pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
525
  pretrained_seg_model = Seg()
 
526
  line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
527
  pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
528
  pretrained_line_seg_model.to(dist_util.dev())
529
  pretrained_line_seg_model.eval()
530
 
 
531
  seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
532
  pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
533
  pretrained_seg_model.to(dist_util.dev())
534
  pretrained_seg_model.eval()
535
 
 
536
  model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
537
  logger.log(f"Model loaded with {settings.env.model_path}")
538
 
 
31
  from datasets.utils.warping import register_model2
32
 
33
  import gradio as gr
34
+ from huggingface_hub import hf_hub_download
35
 
36
 
37
 
 
515
  setattr(diffusion, "settings", settings)
516
 
517
  pretrained_dewarp_model = GeoTr_Seg_Inf()
518
+ settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/dvd", filename="seg.pth")
519
  reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
520
  # reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
521
  pretrained_dewarp_model.to(dist_util.dev())
 
524
  if settings.env.use_line_mask:
525
  pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
526
  pretrained_seg_model = Seg()
527
+ settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/dvd", filename="line_model2.pth")
528
  line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
529
  pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
530
  pretrained_line_seg_model.to(dist_util.dev())
531
  pretrained_line_seg_model.eval()
532
 
533
+ settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/dvd", filename="seg_model.pth")
534
  seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
535
  pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
536
  pretrained_seg_model.to(dist_util.dev())
537
  pretrained_seg_model.eval()
538
 
539
+ settings.env.model_path = hf_hub_download(repo_id="hanquansanren/dvd", filename="model1852000.pt")
540
  model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
541
  logger.log(f"Model loaded with {settings.env.model_path}")
542
 
requirements.txt CHANGED
@@ -3,7 +3,10 @@
3
  # pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu118
4
  # pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1/index.html
5
  # gxx_linux-64
 
6
  mpi4py
 
 
7
  opencv-python
8
  jpeg4py
9
  packaging
 
3
  # pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu118
4
  # pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1/index.html
5
  # gxx_linux-64
6
+ mmcv==2.2.0
7
  mpi4py
8
+ matplotlib
9
+ huggingface_hub
10
  opencv-python
11
  jpeg4py
12
  packaging