KangLiao commited on
Commit
f18fdea
·
1 Parent(s): 0fae652
Files changed (1) hide show
  1. app.py +2 -10
app.py CHANGED
@@ -7,6 +7,7 @@ import math
7
  import re
8
  from einops import rearrange
9
  from mmengine.config import Config
 
10
 
11
  import matplotlib
12
  matplotlib.use("Agg")
@@ -15,10 +16,6 @@ import matplotlib.pyplot as plt
15
  from scripts.camera.cam_dataset import Cam_Generator
16
  from scripts.camera.visualization.visualize_batch import make_perspective_figures
17
 
18
- from mmengine.registry import Registry
19
- __all__ = ['BUILDER']
20
- BUILDER = Registry('builder')
21
-
22
  from huggingface_hub import snapshot_download
23
  import os
24
  local_path = snapshot_download(
@@ -45,16 +42,11 @@ def center_crop(image):
45
  ##### load model
46
  config = "configs/pipelines/stage_2_base.py"
47
  config = Config.fromfile(config)
48
- model = BUILDER.build(config.model).eval()
49
  checkpoint_path = "checkpoints/Puffin-Base.pth"
50
  checkpoint = torch.load(checkpoint_path)
51
  info = model.load_state_dict(checkpoint, strict=False)
52
 
53
- if torch.cuda.is_available():
54
- model = model.to(torch.bfloat16).cuda()
55
- else:
56
- model = model.to(torch.float32)
57
-
58
 
59
  @torch.inference_mode()
60
  @spaces.GPU(duration=120)
 
7
  import re
8
  from einops import rearrange
9
  from mmengine.config import Config
10
+ from xtuner.registry import BUILDER
11
 
12
  import matplotlib
13
  matplotlib.use("Agg")
 
16
  from scripts.camera.cam_dataset import Cam_Generator
17
  from scripts.camera.visualization.visualize_batch import make_perspective_figures
18
 
 
 
 
 
19
  from huggingface_hub import snapshot_download
20
  import os
21
  local_path = snapshot_download(
 
42
  ##### load model
43
  config = "configs/pipelines/stage_2_base.py"
44
  config = Config.fromfile(config)
45
+ model = BUILDER.build(config.model).cuda().bfloat16().eval()
46
  checkpoint_path = "checkpoints/Puffin-Base.pth"
47
  checkpoint = torch.load(checkpoint_path)
48
  info = model.load_state_dict(checkpoint, strict=False)
49
 
 
 
 
 
 
50
 
51
  @torch.inference_mode()
52
  @spaces.GPU(duration=120)