Akis Giannoukos committed
Commit 7ee0100 · 1 Parent(s): 8b938f4

Implement error suppression for TorchInductor in generation functions to enhance stability across environments.

Files changed (1): app.py (+55 -18)

app.py CHANGED
@@ -343,16 +343,36 @@ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
         {"role": "user", "content": combined_prompt},
     ]
     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    gen = pipe(
-        prompt,
-        max_new_tokens=96,
-        temperature=0.7,
-        do_sample=True,
-        top_p=0.9,
-        top_k=50,
-        pad_token_id=tokenizer.eos_token_id,
-        return_full_text=False,
-    )
+    # Avoid TorchInductor graph capture issues on some environments
+    try:
+        import torch._dynamo as _dynamo  # type: ignore
+    except Exception:
+        _dynamo = None
+    if _dynamo is not None:
+        _dynamo.config.suppress_errors = True  # best-effort safe fallback
+    if hasattr(torch, "_dynamo"):
+        with torch._dynamo.disable():  # type: ignore[attr-defined]
+            gen = pipe(
+                prompt,
+                max_new_tokens=96,
+                temperature=0.7,
+                do_sample=True,
+                top_p=0.9,
+                top_k=50,
+                pad_token_id=tokenizer.eos_token_id,
+                return_full_text=False,
+            )
+    else:
+        gen = pipe(
+            prompt,
+            max_new_tokens=96,
+            temperature=0.7,
+            do_sample=True,
+            top_p=0.9,
+            top_k=50,
+            pad_token_id=tokenizer.eos_token_id,
+            return_full_text=False,
+        )
     reply = gen[0]["generated_text"].strip()
     # Ensure it's a single concise question/sentence
     if len(reply) > 300:
@@ -385,14 +405,31 @@ def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str,
     ]
     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     # Use deterministic decoding to avoid CUDA sampling edge cases on some models
-    gen = pipe(
-        prompt,
-        max_new_tokens=256,
-        temperature=0.0,
-        do_sample=False,
-        pad_token_id=tokenizer.eos_token_id,
-        return_full_text=False,
-    )
+    try:
+        import torch._dynamo as _dynamo  # type: ignore
+    except Exception:
+        _dynamo = None
+    if _dynamo is not None:
+        _dynamo.config.suppress_errors = True
+    if hasattr(torch, "_dynamo"):
+        with torch._dynamo.disable():  # type: ignore[attr-defined]
+            gen = pipe(
+                prompt,
+                max_new_tokens=256,
+                temperature=0.0,
+                do_sample=False,
+                pad_token_id=tokenizer.eos_token_id,
+                return_full_text=False,
+            )
+    else:
+        gen = pipe(
+            prompt,
+            max_new_tokens=256,
+            temperature=0.0,
+            do_sample=False,
+            pad_token_id=tokenizer.eos_token_id,
+            return_full_text=False,
+        )
     out_text = gen[0]["generated_text"]
     parsed = safe_json_extract(out_text)
 
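
Both hunks repeat the same suppression-and-fallback branching around pipe(...). Below is a minimal sketch of how that pattern could be factored into a single helper; the helper name generate_with_dynamo_suppressed is hypothetical and not part of this commit, and it assumes pipe is a transformers text-generation pipeline that returns a list of {"generated_text": ...} dicts.

# Hypothetical refactor sketch (not part of the commit): one helper that applies
# the TorchDynamo/TorchInductor error suppression and then runs generation.
from typing import Any, Dict, List

def generate_with_dynamo_suppressed(pipe, prompt: str, **gen_kwargs: Any) -> List[Dict[str, Any]]:
    """Run pipe(prompt, **gen_kwargs), falling back to eager execution if TorchDynamo is unavailable or misbehaves."""
    try:
        import torch._dynamo as _dynamo  # type: ignore
        # Suppress graph-capture errors so Dynamo falls back to eager instead of raising.
        _dynamo.config.suppress_errors = True
    except Exception:
        _dynamo = None

    if _dynamo is not None and hasattr(_dynamo, "disable"):
        # torch._dynamo.disable(fn) returns a wrapper that always runs fn eagerly.
        run_eagerly = _dynamo.disable(lambda: pipe(prompt, **gen_kwargs))
        return run_eagerly()
    return pipe(prompt, **gen_kwargs)

With a helper like this, each call site would reduce to a single call, e.g. gen = generate_with_dynamo_suppressed(pipe, prompt, max_new_tokens=96, temperature=0.7, do_sample=True, top_p=0.9, top_k=50, pad_token_id=tokenizer.eos_token_id, return_full_text=False).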