hamxaameer commited on
Commit
d99cd3e
Β·
verified Β·
1 Parent(s): 2711df5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -17
app.py CHANGED
@@ -353,31 +353,80 @@ def generate_code_from_pseudo(pseudo_code, max_length, temperature, top_k, top_p
353
 
354
  generation_time = time.time() - start_time
355
 
356
- # Decode all sequences
357
  generated_codes = []
358
- for output in outputs:
359
- generated = loaded_tokenizer.decode(output, skip_special_tokens=False)
360
-
361
- # Extract code part
362
- if '<CODE>' in generated:
363
- code = generated.split('<CODE>')[-1].strip()
364
- # Remove special tokens
365
- code = code.replace('<PAD>', '').replace('<SEP>', '').strip()
366
- else:
367
- code = generated
368
-
369
- generated_codes.append(code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
  # Use the first generated code as primary output
372
- primary_code = generated_codes[0]
373
 
374
  # Calculate metrics if reference code is provided
375
  metrics_output = ""
376
  bleu_output = ""
377
 
378
- if reference_code and reference_code.strip():
379
- # Calculate BLEU scores
380
- bleu_1, bleu_2, bleu_3, bleu_4 = calculate_bleu_score(reference_code, primary_code)
 
 
381
 
382
  bleu_output = f"""πŸ“Š BLEU Scores:
383
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@@ -410,6 +459,13 @@ def generate_code_from_pseudo(pseudo_code, max_length, temperature, top_k, top_p
410
  πŸ“ Sequences Generated: {num_sequences}
411
  πŸ”’ Output Length: {len(primary_code)} characters
412
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 
 
 
 
 
 
 
413
  """
414
  else:
415
  metrics_output = f"""⏱️ Generation Time: {generation_time:.2f}s
 
353
 
354
  generation_time = time.time() - start_time
355
 
356
+ # Decode all sequences with error handling
357
  generated_codes = []
358
+ for i, output in enumerate(outputs):
359
+ try:
360
+ # Ensure output is valid tensor and contains valid token IDs
361
+ if output is None:
362
+ continue
363
+
364
+ # Convert to list and filter out None values
365
+ if hasattr(output, 'tolist'):
366
+ token_ids = output.tolist()
367
+ else:
368
+ token_ids = output
369
+
370
+ # Filter out None values and ensure all are integers
371
+ valid_tokens = []
372
+ for token in token_ids:
373
+ if token is not None and isinstance(token, (int, float)):
374
+ valid_tokens.append(int(token))
375
+
376
+ if not valid_tokens:
377
+ generated_codes.append(f"# Generation {i+1} failed: No valid tokens")
378
+ continue
379
+
380
+ # Decode with skip_special_tokens=True for cleaner output
381
+ try:
382
+ generated = loaded_tokenizer.decode(valid_tokens, skip_special_tokens=False)
383
+ except Exception as decode_error:
384
+ # Fallback: try with skip_special_tokens=True
385
+ try:
386
+ generated = loaded_tokenizer.decode(valid_tokens, skip_special_tokens=True)
387
+ except Exception as decode_error2:
388
+ # Last resort: convert tokens to string manually
389
+ generated = f"# Decode failed: {str(decode_error2)}"
390
+
391
+ # Handle None result from decode
392
+ if generated is None:
393
+ generated = f"# Generation {i+1}: Decode returned None"
394
+
395
+ # Extract code part
396
+ if '<CODE>' in generated:
397
+ code = generated.split('<CODE>')[-1].strip()
398
+ # Remove special tokens
399
+ code = code.replace('<PAD>', '').replace('<SEP>', '').replace('</s>', '').replace('<s>', '').strip()
400
+ else:
401
+ code = generated.strip()
402
+
403
+ # Ensure we have some content
404
+ if not code or code.isspace():
405
+ code = f"# Generated sequence {i+1} was empty"
406
+
407
+ generated_codes.append(code)
408
+
409
+ except Exception as decode_error:
410
+ # Handle any other decoding errors
411
+ error_msg = f"# Error decoding sequence {i+1}: {str(decode_error)}"
412
+ generated_codes.append(error_msg)
413
+
414
+ # Ensure we have at least one result
415
+ if not generated_codes:
416
+ generated_codes = ["# No valid generations produced"]
417
 
418
  # Use the first generated code as primary output
419
+ primary_code = generated_codes[0] if generated_codes else "# No code generated"
420
 
421
  # Calculate metrics if reference code is provided
422
  metrics_output = ""
423
  bleu_output = ""
424
 
425
+ if reference_code and reference_code.strip() and not primary_code.startswith('#'):
426
+ # Only calculate metrics if we have valid generated code (not error messages)
427
+ try:
428
+ # Calculate BLEU scores
429
+ bleu_1, bleu_2, bleu_3, bleu_4 = calculate_bleu_score(reference_code, primary_code)
430
 
431
  bleu_output = f"""πŸ“Š BLEU Scores:
432
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 
459
  πŸ“ Sequences Generated: {num_sequences}
460
  πŸ”’ Output Length: {len(primary_code)} characters
461
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
462
+ """
463
+ except Exception as metrics_error:
464
+ metrics_output = f"""⚠️ Metrics calculation failed: {str(metrics_error)}
465
+
466
+ ⏱️ Generation Time: {generation_time:.2f}s
467
+ πŸ“ Sequences Generated: {num_sequences}
468
+ πŸ”’ Output Length: {len(primary_code)} characters
469
  """
470
  else:
471
  metrics_output = f"""⏱️ Generation Time: {generation_time:.2f}s