Akis Giannoukos committed
Commit · 5731404
1 Parent(s): 17f0761
Added dynamic dtype selection and improved decoding parameters

app.py CHANGED
@@ -76,12 +76,18 @@ def get_textgen_pipeline():
     global _gen_pipe
     if _gen_pipe is None:
         # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
+        if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
+            _dtype = torch.bfloat16
+        elif torch.cuda.is_available():
+            _dtype = torch.float16
+        else:
+            _dtype = torch.float32
         _gen_pipe = pipeline(
             task="text-generation",
             model=current_model_id,
             tokenizer=current_model_id,
             device=_hf_device(),
-            torch_dtype=
+            torch_dtype=_dtype,
         )
     return _gen_pipe
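
The dtype branch above can be exercised outside the Space. The sketch below shows the same selection logic as a standalone helper feeding a transformers text-generation pipeline; the helper name and the tiny placeholder model id are illustrative, not part of app.py, which resolves its own current_model_id (overridable via LLM_MODEL_ID).

import torch
from transformers import pipeline

def pick_dtype() -> torch.dtype:
    # Same branch order as the commit: bf16 on GPUs that support it, fp16 on other GPUs, fp32 on CPU.
    if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if torch.cuda.is_available():
        return torch.float16
    return torch.float32

# Placeholder model id kept tiny so the example is cheap to run.
pipe = pipeline(
    task="text-generation",
    model="sshleifer/tiny-gpt2",
    device=0 if torch.cuda.is_available() else -1,
    torch_dtype=pick_dtype(),
)
print(pipe.model.dtype)  # torch.float32 on CPU, torch.float16/bfloat16 on GPU
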
@@ -275,6 +281,8 @@ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
         max_new_tokens=96,
         temperature=0.7,
         do_sample=True,
+        top_p=0.9,
+        top_k=50,
         pad_token_id=tokenizer.eos_token_id,
         return_full_text=False,
     )
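
A quick way to see the effect of the added top_p / top_k arguments is to call a pipeline directly with the same decoding settings. A minimal sketch, assuming a tiny placeholder model rather than the app's configured one:

from transformers import pipeline

pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2")

# top_p=0.9 restricts sampling to the smallest token set whose cumulative probability reaches 0.9;
# top_k=50 additionally caps the candidate set at the 50 most likely tokens.
out = pipe(
    "Hello, thanks for calling.",
    max_new_tokens=32,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    pad_token_id=pipe.tokenizer.eos_token_id,
    return_full_text=False,
)
print(out[0]["generated_text"])
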
@@ -309,11 +317,12 @@ def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str,
         {"role": "user", "content": combined_prompt},
     ]
     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    # Use deterministic decoding to avoid CUDA sampling edge cases on some models
     gen = pipe(
         prompt,
         max_new_tokens=256,
-        temperature=0.
-        do_sample=
+        temperature=0.0,
+        do_sample=False,
         pad_token_id=tokenizer.eos_token_id,
         return_full_text=False,
     )
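
With do_sample=False the scoring call becomes greedy decoding: the highest-probability token is taken at every step, so the same prompt yields the same output across runs, and temperature is not used in this mode (some transformers versions log a warning when it is passed anyway). A minimal sketch with a placeholder model:

from transformers import pipeline

pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2")

# Deterministic decoding: no sampling, so repeated calls return identical text.
out = pipe(
    "Summarize the call in one sentence:",
    max_new_tokens=64,
    do_sample=False,
    pad_token_id=pipe.tokenizer.eos_token_id,
    return_full_text=False,
)
print(out[0]["generated_text"])
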
@@ -438,12 +447,14 @@ def process_turn(
     chat_history: List[Tuple[str, str]],
     threshold: float,
     tts_enabled: bool,
-    finished: bool,
-    turns: int,
+    finished: Optional[bool],
+    turns: Optional[int],
     prev_scores: Dict[str, Any],
     prev_meta: Dict[str, Any],
 ):
     # If already finished, do nothing
+    finished = bool(finished) if finished is not None else False
+    turns = int(turns) if isinstance(turns, int) else 0
     if finished:
         return (
             chat_history,
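
The widened process_turn signature guards against UI state arriving as None on a fresh session. The same coercion can be checked in isolation; the helper name below is illustrative, and the annotations assume Optional and Tuple are imported from typing as in app.py's other hints.

from typing import Optional, Tuple

def coerce_state(finished: Optional[bool], turns: Optional[int]) -> Tuple[bool, int]:
    # Mirror of the commit's normalization: missing flags default to False / 0.
    finished = bool(finished) if finished is not None else False
    turns = int(turns) if isinstance(turns, int) else 0
    return finished, turns

assert coerce_state(None, None) == (False, 0)  # fresh session: nothing set yet
assert coerce_state(True, 3) == (True, 3)      # populated state passes through unchanged
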