Akis Giannoukos committed
Commit 5731404 · 1 Parent(s): 17f0761

Added dynamic dtype selection and improved decoding parameters

Files changed (1)
  1. app.py +16 -5
app.py CHANGED
@@ -76,12 +76,18 @@ def get_textgen_pipeline():
     global _gen_pipe
     if _gen_pipe is None:
         # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
+        if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
+            _dtype = torch.bfloat16
+        elif torch.cuda.is_available():
+            _dtype = torch.float16
+        else:
+            _dtype = torch.float32
         _gen_pipe = pipeline(
             task="text-generation",
             model=current_model_id,
             tokenizer=current_model_id,
             device=_hf_device(),
-            torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
+            torch_dtype=_dtype,
         )
     return _gen_pipe
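For reference, the new branch is equivalent to the standalone helper below (a minimal sketch; the name _select_dtype is illustrative and does not appear in app.py). It prefers bfloat16 on GPUs that report bf16 support, falls back to float16 on other GPUs, and keeps float32 on CPU, the default Spaces hardware.

import torch

def _select_dtype():
    # Prefer bfloat16 where the GPU supports it (same exponent range as float32,
    # so less prone to overflow than float16); otherwise float16 on CUDA, float32 on CPU.
    if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if torch.cuda.is_available():
        return torch.float16
    return torch.float32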
 
@@ -275,6 +281,8 @@ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
         max_new_tokens=96,
         temperature=0.7,
         do_sample=True,
+        top_p=0.9,
+        top_k=50,
         pad_token_id=tokenizer.eos_token_id,
         return_full_text=False,
     )
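The recording-agent reply still samples (temperature=0.7, do_sample=True), but sampling is now constrained by nucleus (top_p) and top-k filtering. As an illustration of what the two new arguments mean, and not part of the commit, the same decoding settings could be written as a transformers GenerationConfig:

from transformers import GenerationConfig

# Same settings as the pipeline call above, expressed as a config object.
reply_config = GenerationConfig(
    max_new_tokens=96,
    temperature=0.7,
    do_sample=True,
    top_p=0.9,  # nucleus sampling: keep the smallest token set whose cumulative probability reaches 0.9
    top_k=50,   # additionally keep at most the 50 most likely tokens
)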
@@ -309,11 +317,12 @@ def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str,
         {"role": "user", "content": combined_prompt},
     ]
     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    # Use deterministic decoding to avoid CUDA sampling edge cases on some models
     gen = pipe(
         prompt,
         max_new_tokens=256,
-        temperature=0.2,
-        do_sample=True,
+        temperature=0.0,
+        do_sample=False,
         pad_token_id=tokenizer.eos_token_id,
         return_full_text=False,
     )
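Note, not part of the commit: with do_sample=False and the default num_beams=1, transformers uses greedy decoding, so the temperature value is effectively ignored and the scoring output is reproducible for a given prompt. A minimal sketch of an equivalent call, assuming the same pipe and tokenizer objects:

# Greedy decoding: the highest-probability token is picked at every step,
# so repeated calls on the same prompt return the same text.
gen = pipe(
    prompt,
    max_new_tokens=256,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    return_full_text=False,
)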
@@ -438,12 +447,14 @@ def process_turn(
     chat_history: List[Tuple[str, str]],
     threshold: float,
     tts_enabled: bool,
-    finished: bool,
-    turns: int,
+    finished: Optional[bool],
+    turns: Optional[int],
     prev_scores: Dict[str, Any],
     prev_meta: Dict[str, Any],
 ):
     # If already finished, do nothing
+    finished = bool(finished) if finished is not None else False
+    turns = int(turns) if isinstance(turns, int) else 0
     if finished:
         return (
             chat_history,
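The widened signature and the two coercion lines appear to guard against the UI state arriving as None, e.g. before the Gradio state components are populated; that reading is inferred from the change, not stated in the commit. A small sketch of the resulting behavior, using a hypothetical _coerce_state helper:

from typing import Optional

def _coerce_state(finished: Optional[bool], turns: Optional[int]):
    # Mirror the guards added to process_turn: None collapses to neutral defaults.
    finished = bool(finished) if finished is not None else False
    turns = int(turns) if isinstance(turns, int) else 0
    return finished, turns

assert _coerce_state(None, None) == (False, 0)  # unset state -> not finished, zero turns
assert _coerce_state(True, 3) == (True, 3)      # populated state passes through unchanged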
 