Spaces:
Runtime error
Runtime error
wait for init
Browse files- app.py +3 -2
- tools/llama/generate.py +4 -0
app.py
CHANGED
|
@@ -306,7 +306,7 @@ if __name__ == "__main__":
|
|
| 306 |
args.vqgan_config_name = "vqgan_pretrain"
|
| 307 |
|
| 308 |
logger.info("Loading Llama model...")
|
| 309 |
-
|
| 310 |
llama_queue = launch_thread_safe_queue(
|
| 311 |
config_name=args.llama_config_name,
|
| 312 |
checkpoint_path=args.llama_checkpoint_path,
|
|
@@ -314,11 +314,12 @@ if __name__ == "__main__":
|
|
| 314 |
precision=args.precision,
|
| 315 |
max_length=args.max_length,
|
| 316 |
compile=args.compile,
|
|
|
|
| 317 |
)
|
| 318 |
llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
|
|
|
| 319 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
| 320 |
|
| 321 |
-
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
| 322 |
vqgan_model = load_vqgan_model(
|
| 323 |
config_name=args.vqgan_config_name,
|
| 324 |
checkpoint_path=args.vqgan_checkpoint_path,
|
|
|
|
| 306 |
args.vqgan_config_name = "vqgan_pretrain"
|
| 307 |
|
| 308 |
logger.info("Loading Llama model...")
|
| 309 |
+
init_event = threading.Event()
|
| 310 |
llama_queue = launch_thread_safe_queue(
|
| 311 |
config_name=args.llama_config_name,
|
| 312 |
checkpoint_path=args.llama_checkpoint_path,
|
|
|
|
| 314 |
precision=args.precision,
|
| 315 |
max_length=args.max_length,
|
| 316 |
compile=args.compile,
|
| 317 |
+
init_event=init_event,
|
| 318 |
)
|
| 319 |
llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
| 320 |
+
init_event.wait()
|
| 321 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
| 322 |
|
|
|
|
| 323 |
vqgan_model = load_vqgan_model(
|
| 324 |
config_name=args.vqgan_config_name,
|
| 325 |
checkpoint_path=args.vqgan_checkpoint_path,
|
tools/llama/generate.py
CHANGED
|
@@ -607,6 +607,7 @@ def launch_thread_safe_queue(
|
|
| 607 |
precision,
|
| 608 |
max_length,
|
| 609 |
compile=False,
|
|
|
|
| 610 |
):
|
| 611 |
input_queue = queue.Queue()
|
| 612 |
|
|
@@ -615,6 +616,9 @@ def launch_thread_safe_queue(
|
|
| 615 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
| 616 |
)
|
| 617 |
|
|
|
|
|
|
|
|
|
|
| 618 |
while True:
|
| 619 |
item = input_queue.get()
|
| 620 |
if item is None:
|
|
|
|
| 607 |
precision,
|
| 608 |
max_length,
|
| 609 |
compile=False,
|
| 610 |
+
init_event=None,
|
| 611 |
):
|
| 612 |
input_queue = queue.Queue()
|
| 613 |
|
|
|
|
| 616 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
| 617 |
)
|
| 618 |
|
| 619 |
+
if init_event is not None:
|
| 620 |
+
init_event.set()
|
| 621 |
+
|
| 622 |
while True:
|
| 623 |
item = input_queue.get()
|
| 624 |
if item is None:
|