Fix cache max_seq_len
- app.py +2 -2
- tools/llama/generate.py +4 -2
app.py
CHANGED

@@ -414,7 +414,7 @@ def build_app():
                 label="Maximum tokens per batch, 0 means no limit",
                 minimum=0,
                 maximum=2048,
-                value=
+                value=0,  # 0 means no limit
                 step=8,
             )

@@ -640,7 +640,7 @@ if __name__ == "__main__":
         reference_audio=None,
         reference_text="",
         max_new_tokens=0,
-        chunk_length=200,
+        chunk_length=200,
         top_p=0.7,
         repetition_penalty=1.2,
         temperature=0.7,
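The value=0 change leans on the convention stated in the slider label itself: 0 means no per-batch token limit. As a rough illustration of how such a sentinel is usually resolved before generation, here is a minimal sketch; resolve_max_new_tokens is a hypothetical helper for illustration, not a function from this repo, and the cap against the model's context window is an assumption.

    # Hypothetical helper (not from this repo): resolve a "0 means no limit"
    # UI value into an actual token budget before generation.
    def resolve_max_new_tokens(max_new_tokens: int, prompt_len: int, max_seq_len: int) -> int:
        budget = max_seq_len - prompt_len  # "no limit" still ends at the context window
        if max_new_tokens <= 0:
            return budget
        return min(max_new_tokens, budget)

Even an "unlimited" setting has to stop at the sequence length the KV cache was sized for, which is exactly what the generate.py change below addresses.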
|
tools/llama/generate.py
CHANGED

@@ -250,9 +250,11 @@ def generate(
     device, dtype = prompt.device, prompt.dtype
     with torch.device(device):
         model.setup_caches(
-
+            max_batch_size=1,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
         )
-
+
     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
     empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)