Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -353,31 +353,80 @@ def generate_code_from_pseudo(pseudo_code, max_length, temperature, top_k, top_p
|
|
| 353 |
|
| 354 |
generation_time = time.time() - start_time
|
| 355 |
|
| 356 |
-
# Decode all sequences
|
| 357 |
generated_codes = []
|
| 358 |
-
for output in outputs:
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
#
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
| 371 |
# Use the first generated code as primary output
|
| 372 |
-
primary_code = generated_codes[0]
|
| 373 |
|
| 374 |
# Calculate metrics if reference code is provided
|
| 375 |
metrics_output = ""
|
| 376 |
bleu_output = ""
|
| 377 |
|
| 378 |
-
if reference_code and reference_code.strip():
|
| 379 |
-
#
|
| 380 |
-
|
|
|
|
|
|
|
| 381 |
|
| 382 |
bleu_output = f"""π BLEU Scores:
|
| 383 |
ββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -410,6 +459,13 @@ def generate_code_from_pseudo(pseudo_code, max_length, temperature, top_k, top_p
|
|
| 410 |
π Sequences Generated: {num_sequences}
|
| 411 |
π’ Output Length: {len(primary_code)} characters
|
| 412 |
ββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
"""
|
| 414 |
else:
|
| 415 |
metrics_output = f"""β±οΈ Generation Time: {generation_time:.2f}s
|
|
|
|
| 353 |
|
| 354 |
generation_time = time.time() - start_time
|
| 355 |
|
| 356 |
+
# Decode all sequences with error handling
|
| 357 |
generated_codes = []
|
| 358 |
+
for i, output in enumerate(outputs):
|
| 359 |
+
try:
|
| 360 |
+
# Ensure output is valid tensor and contains valid token IDs
|
| 361 |
+
if output is None:
|
| 362 |
+
continue
|
| 363 |
+
|
| 364 |
+
# Convert to list and filter out None values
|
| 365 |
+
if hasattr(output, 'tolist'):
|
| 366 |
+
token_ids = output.tolist()
|
| 367 |
+
else:
|
| 368 |
+
token_ids = output
|
| 369 |
+
|
| 370 |
+
# Filter out None values and ensure all are integers
|
| 371 |
+
valid_tokens = []
|
| 372 |
+
for token in token_ids:
|
| 373 |
+
if token is not None and isinstance(token, (int, float)):
|
| 374 |
+
valid_tokens.append(int(token))
|
| 375 |
+
|
| 376 |
+
if not valid_tokens:
|
| 377 |
+
generated_codes.append(f"# Generation {i+1} failed: No valid tokens")
|
| 378 |
+
continue
|
| 379 |
+
|
| 380 |
+
# Decode with skip_special_tokens=True for cleaner output
|
| 381 |
+
try:
|
| 382 |
+
generated = loaded_tokenizer.decode(valid_tokens, skip_special_tokens=False)
|
| 383 |
+
except Exception as decode_error:
|
| 384 |
+
# Fallback: try with skip_special_tokens=True
|
| 385 |
+
try:
|
| 386 |
+
generated = loaded_tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
| 387 |
+
except Exception as decode_error2:
|
| 388 |
+
# Last resort: convert tokens to string manually
|
| 389 |
+
generated = f"# Decode failed: {str(decode_error2)}"
|
| 390 |
+
|
| 391 |
+
# Handle None result from decode
|
| 392 |
+
if generated is None:
|
| 393 |
+
generated = f"# Generation {i+1}: Decode returned None"
|
| 394 |
+
|
| 395 |
+
# Extract code part
|
| 396 |
+
if '<CODE>' in generated:
|
| 397 |
+
code = generated.split('<CODE>')[-1].strip()
|
| 398 |
+
# Remove special tokens
|
| 399 |
+
code = code.replace('<PAD>', '').replace('<SEP>', '').replace('</s>', '').replace('<s>', '').strip()
|
| 400 |
+
else:
|
| 401 |
+
code = generated.strip()
|
| 402 |
+
|
| 403 |
+
# Ensure we have some content
|
| 404 |
+
if not code or code.isspace():
|
| 405 |
+
code = f"# Generated sequence {i+1} was empty"
|
| 406 |
+
|
| 407 |
+
generated_codes.append(code)
|
| 408 |
+
|
| 409 |
+
except Exception as decode_error:
|
| 410 |
+
# Handle any other decoding errors
|
| 411 |
+
error_msg = f"# Error decoding sequence {i+1}: {str(decode_error)}"
|
| 412 |
+
generated_codes.append(error_msg)
|
| 413 |
+
|
| 414 |
+
# Ensure we have at least one result
|
| 415 |
+
if not generated_codes:
|
| 416 |
+
generated_codes = ["# No valid generations produced"]
|
| 417 |
|
| 418 |
# Use the first generated code as primary output
|
| 419 |
+
primary_code = generated_codes[0] if generated_codes else "# No code generated"
|
| 420 |
|
| 421 |
# Calculate metrics if reference code is provided
|
| 422 |
metrics_output = ""
|
| 423 |
bleu_output = ""
|
| 424 |
|
| 425 |
+
if reference_code and reference_code.strip() and not primary_code.startswith('#'):
|
| 426 |
+
# Only calculate metrics if we have valid generated code (not error messages)
|
| 427 |
+
try:
|
| 428 |
+
# Calculate BLEU scores
|
| 429 |
+
bleu_1, bleu_2, bleu_3, bleu_4 = calculate_bleu_score(reference_code, primary_code)
|
| 430 |
|
| 431 |
bleu_output = f"""π BLEU Scores:
|
| 432 |
ββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 459 |
π Sequences Generated: {num_sequences}
|
| 460 |
π’ Output Length: {len(primary_code)} characters
|
| 461 |
ββββββββββββββββββββββββββββββββββββββββ
|
| 462 |
+
"""
|
| 463 |
+
except Exception as metrics_error:
|
| 464 |
+
metrics_output = f"""β οΈ Metrics calculation failed: {str(metrics_error)}
|
| 465 |
+
|
| 466 |
+
β±οΈ Generation Time: {generation_time:.2f}s
|
| 467 |
+
π Sequences Generated: {num_sequences}
|
| 468 |
+
π’ Output Length: {len(primary_code)} characters
|
| 469 |
"""
|
| 470 |
else:
|
| 471 |
metrics_output = f"""β±οΈ Generation Time: {generation_time:.2f}s
|