Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -272,7 +272,7 @@ class BeamNode:
|
|
| 272 |
is_selected_sequence: bool
|
| 273 |
|
| 274 |
|
| 275 |
-
def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequences
|
| 276 |
original_tree = BeamNode(
|
| 277 |
cumulative_score=0,
|
| 278 |
current_token_ix=None,
|
|
@@ -415,8 +415,6 @@ def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequ
|
|
| 415 |
current_token_choice_ix = top_df_selected_filtered.iloc[beam_ix]["token_index"]
|
| 416 |
beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
|
| 417 |
|
| 418 |
-
print(f"Step {step}, beams kept: {beams_to_keep}")
|
| 419 |
-
|
| 420 |
return original_tree
|
| 421 |
|
| 422 |
@spaces.GPU
|
|
@@ -445,14 +443,23 @@ def get_beam_search_html(
|
|
| 445 |
for i, sequence in enumerate(decoded_sequences):
|
| 446 |
markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
|
| 447 |
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
html = generate_html(input_text, original_tree)
|
| 457 |
return html, markdown
|
| 458 |
|
|
|
|
| 272 |
is_selected_sequence: bool
|
| 273 |
|
| 274 |
|
| 275 |
+
def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequences):
|
| 276 |
original_tree = BeamNode(
|
| 277 |
cumulative_score=0,
|
| 278 |
current_token_ix=None,
|
|
|
|
| 415 |
current_token_choice_ix = top_df_selected_filtered.iloc[beam_ix]["token_index"]
|
| 416 |
beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
|
| 417 |
|
|
|
|
|
|
|
| 418 |
return original_tree
|
| 419 |
|
| 420 |
@spaces.GPU
|
|
|
|
| 443 |
for i, sequence in enumerate(decoded_sequences):
|
| 444 |
markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
|
| 445 |
|
| 446 |
+
if number_beams > 1:
|
| 447 |
+
original_tree = generate_beams(
|
| 448 |
+
number_beams,
|
| 449 |
+
input_text,
|
| 450 |
+
outputs.scores[:],
|
| 451 |
+
length_penalty,
|
| 452 |
+
decoded_sequences,
|
| 453 |
+
)
|
| 454 |
+
else:
|
| 455 |
+
original_tree = generate_beams(
|
| 456 |
+
n_beams,
|
| 457 |
+
start_sentence,
|
| 458 |
+
outputs.logits,
|
| 459 |
+
0,
|
| 460 |
+
decoded_sequences,
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
html = generate_html(input_text, original_tree)
|
| 464 |
return html, markdown
|
| 465 |
|