Akis Giannoukos committed
Commit 90061b0 · 1 Parent(s): 7ee0100

Refactor generation functions to use a safe wrapper for HF pipeline calls, improving error handling and stability.

Files changed (1): app.py +32 -46
app.py CHANGED
@@ -91,6 +91,18 @@ def get_textgen_pipeline():
     )
     return _gen_pipe
 
+def _safe_hf_generate(pipe, prompt: str, **gen_kwargs):
+    """Call HF generate pipeline with best-effort fallbacks to avoid TorchDynamo/Inductor issues."""
+    try:
+        return pipe(prompt, **gen_kwargs)
+    except Exception:
+        # Best-effort: disable dynamo via env and retry once
+        try:
+            os.environ["TORCHDYNAMO_DISABLE"] = "1"
+        except Exception:
+            pass
+        return pipe(prompt, **gen_kwargs)
+
 
 def set_current_model_id(new_model_id: str) -> str:
     global current_model_id, _gen_pipe
@@ -348,31 +360,17 @@ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
         import torch._dynamo as _dynamo  # type: ignore
     except Exception:
         _dynamo = None
-    if _dynamo is not None:
-        _dynamo.config.suppress_errors = True  # best-effort safe fallback
-    if hasattr(torch, "_dynamo"):
-        with torch._dynamo.disable():  # type: ignore[attr-defined]
-            gen = pipe(
-                prompt,
-                max_new_tokens=96,
-                temperature=0.7,
-                do_sample=True,
-                top_p=0.9,
-                top_k=50,
-                pad_token_id=tokenizer.eos_token_id,
-                return_full_text=False,
-            )
-    else:
-        gen = pipe(
-            prompt,
-            max_new_tokens=96,
-            temperature=0.7,
-            do_sample=True,
-            top_p=0.9,
-            top_k=50,
-            pad_token_id=tokenizer.eos_token_id,
-            return_full_text=False,
-        )
+    gen = _safe_hf_generate(
+        pipe,
+        prompt,
+        max_new_tokens=96,
+        temperature=0.7,
+        do_sample=True,
+        top_p=0.9,
+        top_k=50,
+        pad_token_id=tokenizer.eos_token_id,
+        return_full_text=False,
+    )
     reply = gen[0]["generated_text"].strip()
     # Ensure it's a single concise question/sentence
     if len(reply) > 300:
@@ -409,27 +407,15 @@ def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str,
         import torch._dynamo as _dynamo  # type: ignore
     except Exception:
         _dynamo = None
-    if _dynamo is not None:
-        _dynamo.config.suppress_errors = True
-    if hasattr(torch, "_dynamo"):
-        with torch._dynamo.disable():  # type: ignore[attr-defined]
-            gen = pipe(
-                prompt,
-                max_new_tokens=256,
-                temperature=0.0,
-                do_sample=False,
-                pad_token_id=tokenizer.eos_token_id,
-                return_full_text=False,
-            )
-    else:
-        gen = pipe(
-            prompt,
-            max_new_tokens=256,
-            temperature=0.0,
-            do_sample=False,
-            pad_token_id=tokenizer.eos_token_id,
-            return_full_text=False,
-        )
+    gen = _safe_hf_generate(
+        pipe,
+        prompt,
+        max_new_tokens=256,
+        temperature=0.0,
+        do_sample=False,
+        pad_token_id=tokenizer.eos_token_id,
+        return_full_text=False,
+    )
     out_text = gen[0]["generated_text"]
     parsed = safe_json_extract(out_text)
 
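
For reference, a minimal, self-contained sketch of how the new wrapper behaves. The _safe_hf_generate body is copied from the diff above; the FlakyPipe stub and the demo at the bottom are illustrative assumptions, standing in for a real transformers text-generation pipeline whose first call hits a TorchDynamo/Inductor error:

import os


def _safe_hf_generate(pipe, prompt: str, **gen_kwargs):
    """Call HF generate pipeline with best-effort fallbacks to avoid TorchDynamo/Inductor issues."""
    try:
        return pipe(prompt, **gen_kwargs)
    except Exception:
        # Best-effort: disable dynamo via env and retry once
        try:
            os.environ["TORCHDYNAMO_DISABLE"] = "1"
        except Exception:
            pass
        return pipe(prompt, **gen_kwargs)


class FlakyPipe:
    """Hypothetical stand-in for a transformers text-generation pipeline
    whose first call raises (e.g. a torch.compile/TorchDynamo failure)."""

    def __init__(self):
        self.calls = 0

    def __call__(self, prompt, **gen_kwargs):
        self.calls += 1
        if self.calls == 1:
            raise RuntimeError("simulated torch._dynamo failure")
        # Mimic the pipeline's output shape with return_full_text=False.
        return [{"generated_text": f"reply to: {prompt}"}]


if __name__ == "__main__":
    pipe = FlakyPipe()
    gen = _safe_hf_generate(pipe, "Hello?", max_new_tokens=96, return_full_text=False)
    print(gen[0]["generated_text"])           # reply to: Hello?
    print(os.environ["TORCHDYNAMO_DISABLE"])  # 1 (set by the fallback path)

Note that the wrapper retries exactly once, so a second failure still propagates to the caller, and since torch may already have read TORCHDYNAMO_DISABLE at import time, setting it mid-process is only a best-effort mitigation, as the docstring says.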