Spaces: Running on Zero
da03 committed · Commit 0ad2aca · 1 Parent(s): 3e9942b
app.py CHANGED

```diff
@@ -37,11 +37,18 @@ def predict_product(num1, num2):
     eos_token_id = tokenizer.eos_token_id
     past_key_values = None
     for _ in range(100):  # Set a maximum limit to prevent infinite loops
-        outputs = model
-
+        outputs = model(
+            input_ids=generated_ids,
+            past_key_values=past_key_values,
+            use_cache=True
+        )
+        logits = outputs.logits
         past_key_values = outputs.past_key_values
 
-
+        next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
+        generated_ids = torch.cat((generated_ids, next_token_id.unsqueeze(-1)), dim=-1)
+
+        if next_token_id.item() == eos_token_id:
             break
 
     output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
```
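For context, the change switches the generation loop to incremental greedy decoding with a key/value cache: each forward pass is run with use_cache=True, the returned past_key_values is fed back in on the next step, and the argmax of the last position's logits is appended to generated_ids until eos_token_id (or the 100-step cap) is reached. Below is a minimal, self-contained sketch of that pattern; the "gpt2" checkpoint, the example prompt, and the feed-only-the-newest-token detail are assumptions for illustration, not taken from this Space's app.py.

```python
# Minimal sketch of token-by-token greedy decoding with a KV cache.
# "gpt2" and the prompt are placeholders, not the model/prompt used by this Space.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

generated_ids = tokenizer("12 * 34 =", return_tensors="pt").input_ids
eos_token_id = tokenizer.eos_token_id
past_key_values = None

with torch.no_grad():
    for _ in range(100):  # cap the number of steps to prevent an infinite loop
        if past_key_values is None:
            input_ids = generated_ids          # first step: encode the full prompt
        else:
            input_ids = generated_ids[:, -1:]  # later steps: only the newest token;
                                               # the cache already covers the rest
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values  # reuse cached keys/values next step

        # Greedy choice: take the highest-probability token at the last position
        next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1)
        generated_ids = torch.cat((generated_ids, next_token_id.unsqueeze(-1)), dim=-1)

        if next_token_id.item() == eos_token_id:
            break

output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(output_text)
```

With use_cache=True, each iteration only attends the newly generated token against the cached keys and values, so the per-step cost stays roughly constant instead of growing with the length of the sequence re-encoded from scratch.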