wlyu-adobe committed on
Commit
133857a
·
1 Parent(s): fc411f7

Initial commit

Browse files
Files changed (2) hide show
  1. app.py +22 -5
  2. requirements.txt +1 -0
app.py CHANGED
@@ -23,6 +23,7 @@ from easydict import EasyDict as edict
23
  from einops import rearrange
24
  from PIL import Image
25
  from huggingface_hub import snapshot_download
 
26
 
27
  # Install diff-gaussian-rasterization at runtime (requires GPU)
28
  import subprocess
@@ -98,14 +99,14 @@ class FaceLiftPipeline:
98
  self.image_size = 512
99
  self.camera_indices = [2, 1, 0, 5, 4, 3]
100
 
101
- # Load models
102
  print("Loading models...")
103
  self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
104
  str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"),
105
  torch_dtype=torch.float16,
106
  )
107
- self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
108
- self.mvdiffusion_pipeline.to(self.device)
109
 
110
  with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
111
  config = edict(yaml.safe_load(f))
@@ -120,11 +121,11 @@ class FaceLiftPipeline:
120
  map_location="cpu"
121
  )
122
  self.gs_lrm_model.load_state_dict(checkpoint["model"])
123
- self.gs_lrm_model.to(self.device)
124
 
125
  self.color_prompt_embedding = torch.load(
126
  workspace_dir / "mvdiffusion/fixed_prompt_embeds_6view/clr_embeds.pt",
127
- map_location=self.device
128
  )
129
 
130
  with open(workspace_dir / "utils_folder/opencv_cameras.json", 'r') as f:
@@ -132,6 +133,18 @@ class FaceLiftPipeline:
132
 
133
  print("Models loaded successfully!")
134
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def _create_viewer_html(self, splat_path):
136
  """Create standalone HTML viewer for the gaussian splat."""
137
  import base64
@@ -246,9 +259,13 @@ class FaceLiftPipeline:
246
  </html>"""
247
  return html
248
 
 
249
  def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
250
  random_seed=4, num_steps=50):
251
  """Generate 3D head from single image."""
 
 
 
252
  try:
253
  # Setup output directory
254
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
23
  from einops import rearrange
24
  from PIL import Image
25
  from huggingface_hub import snapshot_download
26
+ import spaces
27
 
28
  # Install diff-gaussian-rasterization at runtime (requires GPU)
29
  import subprocess
 
99
  self.image_size = 512
100
  self.camera_indices = [2, 1, 0, 5, 4, 3]
101
 
102
+ # Load models (keep on CPU for ZeroGPU compatibility)
103
  print("Loading models...")
104
  self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
105
  str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"),
106
  torch_dtype=torch.float16,
107
  )
108
+ # Don't move to device or enable xformers here - will be done in GPU-decorated function
109
+ self._models_on_gpu = False
110
 
111
  with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
112
  config = edict(yaml.safe_load(f))
 
121
  map_location="cpu"
122
  )
123
  self.gs_lrm_model.load_state_dict(checkpoint["model"])
124
+ # Keep on CPU initially - will move to GPU in decorated function
125
 
126
  self.color_prompt_embedding = torch.load(
127
  workspace_dir / "mvdiffusion/fixed_prompt_embeds_6view/clr_embeds.pt",
128
+ map_location="cpu"
129
  )
130
 
131
  with open(workspace_dir / "utils_folder/opencv_cameras.json", 'r') as f:
 
133
 
134
  print("Models loaded successfully!")
135
 
136
+ def _move_models_to_gpu(self):
137
+ """Move models to GPU and enable optimizations. Called within @spaces.GPU context."""
138
+ if not self._models_on_gpu and torch.cuda.is_available():
139
+ print("Moving models to GPU...")
140
+ self.device = torch.device("cuda:0")
141
+ self.mvdiffusion_pipeline.to(self.device)
142
+ self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
143
+ self.gs_lrm_model.to(self.device)
144
+ self.color_prompt_embedding = self.color_prompt_embedding.to(self.device)
145
+ self._models_on_gpu = True
146
+ print("Models on GPU, xformers enabled!")
147
+
148
  def _create_viewer_html(self, splat_path):
149
  """Create standalone HTML viewer for the gaussian splat."""
150
  import base64
 
259
  </html>"""
260
  return html
261
 
262
+ @spaces.GPU
263
  def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
264
  random_seed=4, num_steps=50):
265
  """Generate 3D head from single image."""
266
+ # Move models to GPU now that we're in the GPU context
267
+ self._move_models_to_gpu()
268
+
269
  try:
270
  # Setup output directory
271
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
requirements.txt CHANGED
@@ -25,3 +25,4 @@ jaxtyping==0.2.19
25
  pytorch-msssim==1.0.0
26
  ffmpeg-python==0.2.0
27
  tqdm
 
 
25
  pytorch-msssim==1.0.0
26
  ffmpeg-python==0.2.0
27
  tqdm
28
+ spaces