Spaces:
Runtime error
Runtime error
Antoni Bigata
commited on
Commit
Β·
7746897
1
Parent(s):
2fb3e22
addapt for zerogpu
Browse files
app.py
CHANGED
|
@@ -23,6 +23,7 @@ from inference_functions import (
|
|
| 23 |
)
|
| 24 |
from wordle_game import WordleGame
|
| 25 |
import torch.cuda.amp as amp # Import amp for mixed precision
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
# Set default tensor type to float16 for faster computation
|
|
@@ -96,10 +97,26 @@ def load_model(
|
|
| 96 |
return model
|
| 97 |
|
| 98 |
|
| 99 |
-
#
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
vae_model = vae_model.half() # Convert to half precision
|
| 104 |
try:
|
| 105 |
vae_model = torch.compile(vae_model)
|
|
@@ -107,8 +124,7 @@ if torch.cuda.is_available():
|
|
| 107 |
except Exception as e:
|
| 108 |
print(f"Warning: Failed to compile vae_model: {e}")
|
| 109 |
|
| 110 |
-
hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
|
| 111 |
-
if torch.cuda.is_available():
|
| 112 |
hubert_model = hubert_model.half() # Convert to half precision
|
| 113 |
try:
|
| 114 |
hubert_model = torch.compile(hubert_model)
|
|
@@ -116,13 +132,13 @@ if torch.cuda.is_available():
|
|
| 116 |
except Exception as e:
|
| 117 |
print(f"Warning: Failed to compile hubert_model: {e}")
|
| 118 |
|
| 119 |
-
wavlm_model = WavLM_wrapper(
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
).cuda()
|
| 125 |
-
|
| 126 |
wavlm_model = wavlm_model.half() # Convert to half precision
|
| 127 |
try:
|
| 128 |
wavlm_model = torch.compile(wavlm_model)
|
|
@@ -130,27 +146,23 @@ if torch.cuda.is_available():
|
|
| 130 |
except Exception as e:
|
| 131 |
print(f"Warning: Failed to compile wavlm_model: {e}")
|
| 132 |
|
| 133 |
-
landmarks_extractor = LandmarksExtractor()
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
os.path.dirname(__file__), "assets", "sample_video.mp4"
|
| 148 |
-
)
|
| 149 |
-
DEFAULT_AUDIO_PATH = os.path.join(
|
| 150 |
-
os.path.dirname(__file__), "assets", "sample_audio.wav"
|
| 151 |
-
)
|
| 152 |
|
| 153 |
|
|
|
|
| 154 |
@torch.no_grad()
|
| 155 |
def compute_video_embedding(video_reader, min_len):
|
| 156 |
"""Compute embeddings from video"""
|
|
@@ -200,6 +212,7 @@ def compute_video_embedding(video_reader, min_len):
|
|
| 200 |
return encoded, video_frames
|
| 201 |
|
| 202 |
|
|
|
|
| 203 |
@torch.no_grad()
|
| 204 |
def compute_hubert_embedding(raw_audio):
|
| 205 |
"""Compute embeddings from audio"""
|
|
@@ -246,6 +259,7 @@ def compute_hubert_embedding(raw_audio):
|
|
| 246 |
return audio_embeddings
|
| 247 |
|
| 248 |
|
|
|
|
| 249 |
@torch.no_grad()
|
| 250 |
def compute_wavlm_embedding(raw_audio):
|
| 251 |
"""Compute embeddings from audio"""
|
|
@@ -352,6 +366,7 @@ def extract_video_landmarks(video_frames):
|
|
| 352 |
return np.array(processed_landmarks)
|
| 353 |
|
| 354 |
|
|
|
|
| 355 |
@torch.no_grad()
|
| 356 |
def sample(
|
| 357 |
audio_list,
|
|
|
|
| 23 |
)
|
| 24 |
from wordle_game import WordleGame
|
| 25 |
import torch.cuda.amp as amp # Import amp for mixed precision
|
| 26 |
+
import spaces
|
| 27 |
|
| 28 |
|
| 29 |
# Set default tensor type to float16 for faster computation
|
|
|
|
| 97 |
return model
|
| 98 |
|
| 99 |
|
| 100 |
+
# Default media paths
|
| 101 |
+
DEFAULT_VIDEO_PATH = os.path.join(
|
| 102 |
+
os.path.dirname(__file__), "assets", "sample_video.mp4"
|
| 103 |
+
)
|
| 104 |
+
DEFAULT_AUDIO_PATH = os.path.join(
|
| 105 |
+
os.path.dirname(__file__), "assets", "sample_audio.wav"
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@spaces.GPU(duration=60)
|
| 110 |
+
def load_all_models():
|
| 111 |
+
global \
|
| 112 |
+
keyframe_model, \
|
| 113 |
+
interpolation_model, \
|
| 114 |
+
vae_model, \
|
| 115 |
+
hubert_model, \
|
| 116 |
+
wavlm_model, \
|
| 117 |
+
landmarks_extractor
|
| 118 |
+
vae_model = VaeWrapper("video")
|
| 119 |
+
|
| 120 |
vae_model = vae_model.half() # Convert to half precision
|
| 121 |
try:
|
| 122 |
vae_model = torch.compile(vae_model)
|
|
|
|
| 124 |
except Exception as e:
|
| 125 |
print(f"Warning: Failed to compile vae_model: {e}")
|
| 126 |
|
| 127 |
+
hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
|
|
|
|
| 128 |
hubert_model = hubert_model.half() # Convert to half precision
|
| 129 |
try:
|
| 130 |
hubert_model = torch.compile(hubert_model)
|
|
|
|
| 132 |
except Exception as e:
|
| 133 |
print(f"Warning: Failed to compile hubert_model: {e}")
|
| 134 |
|
| 135 |
+
wavlm_model = WavLM_wrapper(
|
| 136 |
+
model_size="Base+",
|
| 137 |
+
feed_as_frames=False,
|
| 138 |
+
merge_type="None",
|
| 139 |
+
model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
|
| 140 |
+
).cuda()
|
| 141 |
+
|
| 142 |
wavlm_model = wavlm_model.half() # Convert to half precision
|
| 143 |
try:
|
| 144 |
wavlm_model = torch.compile(wavlm_model)
|
|
|
|
| 146 |
except Exception as e:
|
| 147 |
print(f"Warning: Failed to compile wavlm_model: {e}")
|
| 148 |
|
| 149 |
+
landmarks_extractor = LandmarksExtractor()
|
| 150 |
+
keyframe_model = load_model(
|
| 151 |
+
config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
|
| 152 |
+
ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
|
| 153 |
+
)
|
| 154 |
+
interpolation_model = load_model(
|
| 155 |
+
config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
|
| 156 |
+
ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
|
| 157 |
+
)
|
| 158 |
+
keyframe_model.en_and_decode_n_samples_a_time = 2
|
| 159 |
+
interpolation_model.en_and_decode_n_samples_a_time = 2
|
| 160 |
|
| 161 |
+
|
| 162 |
+
load_all_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
+
@spaces.GPU(duration=60)
|
| 166 |
@torch.no_grad()
|
| 167 |
def compute_video_embedding(video_reader, min_len):
|
| 168 |
"""Compute embeddings from video"""
|
|
|
|
| 212 |
return encoded, video_frames
|
| 213 |
|
| 214 |
|
| 215 |
+
@spaces.GPU(duration=120)
|
| 216 |
@torch.no_grad()
|
| 217 |
def compute_hubert_embedding(raw_audio):
|
| 218 |
"""Compute embeddings from audio"""
|
|
|
|
| 259 |
return audio_embeddings
|
| 260 |
|
| 261 |
|
| 262 |
+
@spaces.GPU(duration=120)
|
| 263 |
@torch.no_grad()
|
| 264 |
def compute_wavlm_embedding(raw_audio):
|
| 265 |
"""Compute embeddings from audio"""
|
|
|
|
| 366 |
return np.array(processed_landmarks)
|
| 367 |
|
| 368 |
|
| 369 |
+
@spaces.GPU(duration=600)
|
| 370 |
@torch.no_grad()
|
| 371 |
def sample(
|
| 372 |
audio_list,
|