Update app.py
app.py
CHANGED
@@ -18,6 +18,11 @@ import filelock
 import glob
 import json
 import time
+from gradio.routes import Request
+from gradio.utils import SyncToAsyncIterator, async_iteration
+from gradio.helpers import special_args
+import anyio
+from typing import AsyncGenerator, Callable, Literal, Union, cast

 from gradio_client.documentation import document, set_documentation_group

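The five added imports are gradio internals rather than public API; they are consumed by the patched `ChatInterface._stream_fn` added further down in this diff. A minimal defensive sketch (not part of this commit, assuming a gradio 3.x release where these modules exist) would pin the dependency at import time:

    # Hypothetical guard, not in this commit: fail early with a clear message if
    # the gradio internals used by the patched ChatInterface methods are missing.
    try:
        from gradio.routes import Request
        from gradio.utils import SyncToAsyncIterator, async_iteration
        from gradio.helpers import special_args
    except ImportError as exc:
        raise RuntimeError("app.py patches gradio.ChatInterface internals; "
                           "a gradio 3.x release that provides them is required") from exc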
@@ -278,455 +283,6 @@ path_markdown = """



-
-def custom_hf_model_weights_iterator(
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-) -> Iterator[Tuple[str, torch.Tensor]]:
-    # ! if use vllm==0.1.4, use this to augment hf_model_weights_iterator loader
-    from vllm.model_executor.weight_utils import Disabledtqdm
-    # Prepare file lock directory to prevent multiple processes from
-    # downloading the same model weights at the same time.
-    lock_dir = cache_dir if cache_dir is not None else "/tmp"
-    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
-    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
-
-    # Download model weights from huggingface.
-    is_local = os.path.isdir(model_name_or_path)
-    if not is_local:
-        with lock:
-            hf_folder = snapshot_download(model_name_or_path,
-                                          allow_patterns="*.bin",
-                                          cache_dir=cache_dir,
-                                          local_files_only=True,
-                                          tqdm_class=Disabledtqdm)
-    else:
-        hf_folder = model_name_or_path
-
-    hf_bin_files = [
-        x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
-        if not x.endswith("training_args.bin")
-    ]
-    hf_safetensors_files = [
-        x for x in glob.glob(os.path.join(hf_folder, "*model*.safetensors"))
-        if not x.endswith("training_args.bin")
-    ]
-
-    if use_np_cache:
-        # Convert the model weights from torch tensors to numpy arrays for
-        # faster loading.
-        np_folder = os.path.join(hf_folder, "np")
-        os.makedirs(np_folder, exist_ok=True)
-        weight_names_file = os.path.join(np_folder, "weight_names.json")
-        with lock:
-            if not os.path.exists(weight_names_file):
-                weight_names = []
-                for bin_file in hf_bin_files:
-                    state = torch.load(bin_file, map_location="cpu")
-                    for name, param in state.items():
-                        param_path = os.path.join(np_folder, name)
-                        with open(param_path, "wb") as f:
-                            np.save(f, param.cpu().detach().numpy())
-                        weight_names.append(name)
-                with open(weight_names_file, "w") as f:
-                    json.dump(weight_names, f)
-
-        with open(weight_names_file, "r") as f:
-            weight_names = json.load(f)
-
-        for name in weight_names:
-            param_path = os.path.join(np_folder, name)
-            with open(param_path, "rb") as f:
-                param = np.load(f)
-            yield name, torch.from_numpy(param)
-    else:
-        if len(hf_bin_files) > 0:
-            print(F'Load bin files: {hf_bin_files}')
-            for bin_file in hf_bin_files:
-                state = torch.load(bin_file, map_location="cpu")
-                for name, param in state.items():
-                    yield name, param
-                del state
-                torch.cuda.empty_cache()
-        elif len(hf_safetensors_files) > 0:
-            print(F'Load safetensor files: {hf_safetensors_files}')
-            from safetensors.torch import load_file
-            for safe_file in hf_safetensors_files:
-                # state = torch.load(bin_file, map_location="cpu")
-                state = load_file(safe_file)
-                for name, param in state.items():
-                    yield name, param
-                del state
-                torch.cuda.empty_cache()
-        else:
-            raise ValueError(f'no files available either bin or safe')
-
-
-def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
-    """convert PySafeSlice object from safetensors to torch.Tensor
-
-    PySafeSlice object supports indexing, which is done before loading the
-    actual tensor and can reduce the amount of memory being read into the
-    memory. However, it does not support more advanced functionalities
-    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
-    tensor with these more complicated operators, we need to convert to
-    tensor first.
-    """
-    if not isinstance(x, torch.Tensor):
-        x = x[:]
-    return x
-
-
-def load_padded_tensor_parallel_vocab(
-    param: torch.Tensor,
-    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
-    tensor_model_parallel_rank: int,
-) -> None:
-    shard_size = param.shape[0]
-    start_idx = tensor_model_parallel_rank * shard_size
-    end_idx = (tensor_model_parallel_rank + 1) * shard_size
-    loaded_weight = loaded_weight[start_idx:end_idx]
-    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
-    param[:loaded_weight.shape[0]].copy_(loaded_weight)
-
-
-def llama_load_weights(
-    self,
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-    load_format: str = "auto",
-    revision: Optional[str] = None
-):
-    # if use vllm==0.1.4
-    from vllm.model_executor.weight_utils import (
-        load_tensor_parallel_weights
-    )
-    from vllm.model_executor.parallel_utils.parallel_state import (
-        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-    tp_size = get_tensor_model_parallel_world_size()
-    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-
-    q_proj_shard_size = (self.config.hidden_size // tp_size)
-    kv_proj_shard_size = (self.config.hidden_size //
-                          self.config.num_attention_heads *
-                          getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) // tp_size)
-    attention_weight_specs = [
-        # (weight_name, shard_size, offset)
-        ("q_proj", q_proj_shard_size, 0),
-        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
-        ("v_proj", kv_proj_shard_size,
-         q_proj_shard_size + kv_proj_shard_size),
-    ]
-    state_dict = self.state_dict()
-    need_to_load = len(state_dict)
-    loaded = 0
-    iterator = custom_hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
-
-    for name, loaded_weight in iterator:
-        if "rotary_emb.inv_freq" in name:
-            continue
-
-        if "embed_tokens" in name or "lm_head" in name:
-            param = state_dict[name]
-            # Consider padding in the vocab size.
-            padded_vocab_size = (param.shape[0] * tp_size)
-            # num_extra_rows = padded_vocab_size - self.config.vocab_size
-            num_extra_rows = padded_vocab_size - loaded_weight.size(0)
-            load_size = loaded_weight.size()
-            extra_rows = torch.empty(num_extra_rows,
-                                     loaded_weight.shape[1])
-            extra_rows = extra_rows.to(loaded_weight)
-            loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-            if num_extra_rows > 0:
-                print(f'Add empty to {num_extra_rows} extra row for {name}')
-            print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
-
-        is_attention_weight = False
-        for weight_name, shard_size, offset in attention_weight_specs:
-            if weight_name not in name or "qkv_proj" in name:
-                continue
-            param = state_dict[name.replace(weight_name, "qkv_proj")]
-
-            loaded_weight = loaded_weight[
-                shard_size * tensor_model_parallel_rank:shard_size *
-                (tensor_model_parallel_rank + 1)]
-            param_slice = param.data[offset:offset + shard_size]
-            assert param_slice.shape == loaded_weight.shape
-
-            param_slice.copy_(loaded_weight)
-            loaded += 1.0 / 3
-            is_attention_weight = True
-            break
-        if is_attention_weight:
-            continue
-
-        # ! qkv_proj is sharded differently if concatenated into qkv
-        # qkv: qqqq kkkk vvvv
-        # lweight: qq0qq1 kk0kk1 vv0vv1
-        # q_shard_size: hidden_size // tp_size = qq
-        # qkv_s0: qq0_kk0_vv0
-        # qkv_s1: qq1_kk1_vv1
-        if "qkv_proj" in name:
-            param = state_dict[name]
-            # loaded_weight
-            qsize = self.config.hidden_size
-            kvsize = self.config.hidden_size // self.config.num_attention_heads * getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
-            q_offsets = (
-                q_proj_shard_size * tensor_model_parallel_rank,
-                q_proj_shard_size * (tensor_model_parallel_rank + 1)
-            )
-            k_offsets = (
-                qsize + kv_proj_shard_size * tensor_model_parallel_rank,
-                qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
-            )
-            v_offsets = (
-                qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank,
-                qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
-            )
-            _loaded_weight = torch.cat(
-                [
-                    loaded_weight[q_offsets[0]:q_offsets[1]],
-                    loaded_weight[k_offsets[0]:k_offsets[1]],
-                    loaded_weight[v_offsets[0]:v_offsets[1]],
-                ], 0
-            )
-            assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
-            param.data.copy_(_loaded_weight)
-            loaded += 1.0
-            is_attention_weight = True
-        if is_attention_weight:
-            continue
-
-
-        is_gate_up_weight = False
-        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-            if weight_name not in name or "gate_up_proj" in name:
-                continue
-            param = state_dict[name.replace(weight_name, "gate_up_proj")]
-            shard_size = param.shape[0] // 2
-            loaded_weight = loaded_weight[
-                shard_size * tensor_model_parallel_rank:shard_size *
-                (tensor_model_parallel_rank + 1)]
-            param_slice = param.data[shard_size * stride_id:shard_size *
-                                     (stride_id + 1)]
-            assert param_slice.shape == loaded_weight.shape
-            param_slice.copy_(loaded_weight)
-            loaded += 1.0 / 2
-            is_gate_up_weight = True
-            break
-        if is_gate_up_weight:
-            continue
-
-        if "gate_up_proj" in name:
-            param = state_dict[name]
-            shard_size = param.shape[0] // 2
-            intermediate_size = self.config.intermediate_size
-            g_offsets = (
-                shard_size * tensor_model_parallel_rank,
-                shard_size * (tensor_model_parallel_rank + 1)
-            )
-            u_offsets = (
-                intermediate_size + shard_size * tensor_model_parallel_rank,
-                intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
-            )
-            _loaded_weight = torch.cat(
-                [
-                    loaded_weight[g_offsets[0]:g_offsets[1]],
-                    loaded_weight[u_offsets[0]:u_offsets[1]],
-                ], 0
-            )
-            assert param.shape == _loaded_weight.shape
-            param.data.copy_(_loaded_weight)
-            loaded += 1.0
-            is_gate_up_weight = True
-        if is_gate_up_weight:
-            continue
-
-
-        param = state_dict[name]
-        load_tensor_parallel_weights(param, loaded_weight, name,
-                                     self._column_parallel_weights,
-                                     self._row_parallel_weights,
-                                     tensor_model_parallel_rank)
-        loaded += 1
-
-    if np.abs(loaded - need_to_load) < 0.01:
-        print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
-    else:
-        print(f'Loaded all {loaded} params loaded out of {need_to_load}')
-
-
-def new_llama_load_weights(
-    self,
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    load_format: str = "auto",
-    revision: Optional[str] = None
-):
-    # If use newest vllm, not been thoroughly tested yet.
-    from vllm.model_executor.weight_utils import (
-        load_tensor_parallel_weights, hf_model_weights_iterator
-    )
-    from vllm.model_executor.parallel_utils.parallel_state import (
-        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-
-    if self.quant_config is None:
-        weight_suffixes = ["weight"]
-    else:
-        weight_suffixes = self.quant_config.get_tp_tensor_names()
-
-    column_parallel_weights: List[str] = []
-    for layer in self._column_parallel_layers:
-        for suffix in weight_suffixes:
-            column_parallel_weights.append(f"{layer}.{suffix}")
-    row_parallel_weights: List[str] = []
-    for layer in self._row_parallel_layers:
-        for suffix in weight_suffixes:
-            row_parallel_weights.append(f"{layer}.{suffix}")
-
-    tp_size = get_tensor_model_parallel_world_size()
-    tp_rank = get_tensor_model_parallel_rank()
-    assert tp_size == 1, f'tensorparallel >=2 not allowed. {tp_size}'
-    q_proj_shard_size = (self.config.hidden_size // tp_size)
-    num_kv_heads_replicas = max(1,
-                                tp_size // self.config.num_key_value_heads)
-    num_kv_heads_per_gpu = max(1,
-                               self.config.num_key_value_heads // tp_size)
-    kv_proj_shard_size = (self.config.hidden_size //
-                          self.config.num_attention_heads *
-                          num_kv_heads_per_gpu)
-    attention_weight_specs = [
-        # (weight_name, shard_size, offset)
-        ("q_proj", q_proj_shard_size, 0),
-        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
-        ("v_proj", kv_proj_shard_size,
-         q_proj_shard_size + kv_proj_shard_size),
-    ]
-    state_dict = self.state_dict()
-    need_to_load = len(state_dict)
-    loaded = 0
-
-    for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision):
-        if "rotary_emb.inv_freq" in name:
-            continue
-
-        is_packed = False
-        is_transposed = False
-        if self.quant_config is not None:
-            is_packed = self.quant_config.is_packed(name)
-            is_transposed = self.quant_config.is_transposed(name)
-        if is_transposed:
-            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
-            loaded_weight = loaded_weight.T
-
-        is_attention_weight = False
-        for weight_name, shard_size, offset in attention_weight_specs:
-            if weight_name not in name or "qkv_proj" in name:
-                continue
-            param = state_dict[name.replace(weight_name, "qkv_proj")]
-            if is_transposed:
-                param = param.T
-
-            if is_packed:
-                shard_size //= self.quant_config.pack_factor
-                offset //= self.quant_config.pack_factor
-
-            if weight_name in ["k_proj", "v_proj"]:
-                shard_id = tp_rank // num_kv_heads_replicas
-            else:
-                shard_id = tp_rank
-            loaded_weight = loaded_weight[shard_size *
-                                          shard_id:shard_size *
-                                          (shard_id + 1)]
-            param_slice = param.data[offset:offset + shard_size]
-            assert param_slice.shape == loaded_weight.shape
-
-            param_slice.copy_(loaded_weight)
-            loaded += 1.0 / 3
-            is_attention_weight = True
-            break
-        if is_attention_weight:
-            continue
-
-        # TODO: need to figure out to do sharding with qkv_proj fused
-
-        is_gate_up_weight = False
-        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-            if weight_name not in name or "gate_up_proj" in name:
-                continue
-            param = state_dict[name.replace(weight_name, "gate_up_proj")]
-            if is_transposed:
-                param = param.T
-
-            shard_size = param.shape[0] // 2
-            loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
-                                          (tp_rank + 1)]
-            param_slice = param.data[shard_size * stride_id:shard_size *
-                                     (stride_id + 1)]
-            assert param_slice.shape == loaded_weight.shape
-            param_slice.copy_(loaded_weight)
-            loaded += 1.0 / 2
-            is_gate_up_weight = True
-            break
-        if is_gate_up_weight:
-            continue
-
-        # TODO: need to figure out to do sharding with gate_up_proj fused
-
-        param = state_dict[name]
-        if is_transposed:
-            param = param.T
-
-        if "embed_tokens" in name or "lm_head" in name:
-            load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                              tp_rank)
-            loaded += 1
-            continue
-
-        load_tensor_parallel_weights(param, loaded_weight, name,
-                                     column_parallel_weights,
-                                     row_parallel_weights, tp_rank)
-        loaded += 1
-
-    if np.abs(loaded - need_to_load) < 0.01:
-        print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
-    else:
-        print(f'Loaded all {loaded} params loaded out of {need_to_load}')
-
-
-# Reassign LlamaForCausalLM.load_weights with llama_load_weights
-if not DEBUG:
-
-    try:
-        import vllm
-        from vllm.model_executor.model_loader import _MODEL_REGISTRY
-        from vllm.model_executor.models import LlamaForCausalLM
-
-        _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
-        if vllm.__version__ == "0.1.4":
-            LlamaForCausalLM.load_weights = llama_load_weights
-        else:
-            LlamaForCausalLM.load_weights = new_llama_load_weights
-
-        if DTYPE == "bfloat16":
-            try:
-                compute_capability = torch.cuda.get_device_capability()
-                if compute_capability[0] < 8:
-                    gpu_name = torch.cuda.get_device_name()
-                    print(
-                        "Bfloat16 is only supported on GPUs with compute capability "
-                        f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                        f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
-                    DTYPE = "float16"
-            except Exception as e:
-                print(f'Unable to obtain compute_capability: {e}')
-    except Exception as e:
-        print(f'Failing import and reconfigure VLLM: {str(e)}')
-
-
 # ! ==================================================================
 
 set_documentation_group("component")
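For context, the block removed above existed only to monkey-patch weight loading in vllm 0.1.4 (bin/safetensors shard discovery, tensor-parallel sharding of the fused qkv and gate_up projections, padded vocab rows). The core pattern it relied on, shown here as a minimal sketch with `my_load_weights` as a placeholder name and valid only for old vllm releases that exposed `_MODEL_REGISTRY`, was:

    # Sketch of the removed monkey-patch pattern (vllm 0.1.x internals; placeholder body).
    import vllm
    from vllm.model_executor.model_loader import _MODEL_REGISTRY
    from vllm.model_executor.models import LlamaForCausalLM

    def my_load_weights(self, model_name_or_path, cache_dir=None, **kwargs):
        ...  # custom iterator over .bin / .safetensors shards, copied into self.state_dict()

    _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM   # alias the architecture name
    LlamaForCausalLM.load_weights = my_load_weights                # swap in the custom loader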
@@ -734,41 +290,6 @@ set_documentation_group("component")
 
 RES_PRINTED = False
 
-def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
-    return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {text} {E_INST}"
-
-
-def llama_chat_multiturn_sys_input_seq_constructor(
-    message: str,
-    history: List[Tuple[str, str]],
-    sys_prompt=SYSTEM_PROMPT_1,
-    bos_token=BOS_TOKEN,
-    eos_token=EOS_TOKEN,
-    include_end_instruct=True,
-):
-    """
-    ```
-    <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
-    <bos>[INST] Prompt [/INST] Answer <eos>
-    <bos>[INST] Prompt [/INST]
-    ```
-    """
-    text = ''
-    end_instr = f" {E_INST}" if include_end_instruct else ""
-    for i, (prompt, res) in enumerate(history):
-        if i == 0:
-            text += f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt}{end_instr}"
-        else:
-            text += f"{bos_token}{B_INST} {prompt}{end_instr}"
-
-        if res is not None:
-            text += f" {res} {eos_token} "
-    if len(history) == 0 or text.strip() == '':
-        text = f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message}{end_instr}"
-    else:
-        text += f"{bos_token}{B_INST} {message}{end_instr}"
-    return text
-
 
 @document()
 class ChatBot(gr.Chatbot):
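The Llama-2 `[INST]`-style prompt constructors removed above are superseded by `chatml_format`, which is defined elsewhere in app.py and is not shown in this diff. Purely as an illustration of the ChatML-style layout the app now targets (the exact role and end tokens are an assumption, not taken from this commit):

    # Illustrative only; the real formatting lives in chatml_format / chatml_chat_convo_format.
    def chatml_turn(role: str, content: str, end_token: str = "</s>") -> str:
        return f"<|im_start|>{role}\n{content}{end_token}"

    prompt = (
        chatml_turn("system", "You are a helpful assistant.")
        + chatml_turn("user", "Hello!")
        + "<|im_start|>assistant\n"          # generation continues from here
    )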
@@ -966,29 +487,63 @@ def _setup_events(self) -> None:
         )
 
     # Reconfigure clear_btn to stop and clear text box
-    # if self.clear_btn:
-    #     self.clear_btn.click(
-    #         lambda: ([], [], None),
-    #         None,
-    #         [self.chatbot, self.chatbot_state, self.saved_input],
-    #         queue=False,
-    #         api_name=False,
-    #         cancels=submit_event,
-    #     )
 
 
 def _display_input(
-    self, message: str, history:
-) ->
+    self, message: str, history: List[List[Union[str, None]]]
+) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
     if message is not None and message.strip() != "":
         history.append([message, None])
     return history, history
 
 
+async def _stream_fn(
+    self,
+    message: str,
+    history_with_input,
+    request: Request,
+    *args,
+) -> AsyncGenerator:
+    history = history_with_input[:-1]
+    inputs, _, _ = special_args(
+        self.fn, inputs=[message, history, *args], request=request
+    )
+
+    if self.is_async:
+        generator = self.fn(*inputs)
+    else:
+        generator = await anyio.to_thread.run_sync(
+            self.fn, *inputs, limiter=self.limiter
+        )
+        generator = SyncToAsyncIterator(generator, self.limiter)
+    try:
+        first_response = await async_iteration(generator)
+        update = history + [[message, first_response]]
+        yield update, update
+    except StopIteration:
+        update = history + [[message, None]]
+        yield update, update
+    try:
+        async for response in generator:
+            update = history + [[message, response]]
+            yield update, update
+    except Exception as e:
+        # if "invalid" in str(e):
+        #     yield history, history
+        #     raise e
+        # else:
+        #     raise e
+        yield history, history
+        raise e
+
+
+
+
 # replace
 gr.ChatInterface._setup_stop_events = _setup_stop_events
 gr.ChatInterface._setup_events = _setup_events
 gr.ChatInterface._display_input = _display_input
+gr.ChatInterface._stream_fn = _stream_fn
 
 
 @document()
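The `_stream_fn` override mirrors gradio's own implementation: the response function stays a plain synchronous generator, is moved to a worker thread via `anyio`, and each yielded string becomes a chatbot update; the added `except Exception` branch re-yields the untouched history before re-raising, so errors (for example the `gr.Error` raised by the safety check later in this diff) roll back the in-progress turn instead of leaving it half-rendered. A minimal sketch of a response function driven by this patch (illustrative names only):

    # Any sync generator works; the patched _stream_fn handles the async plumbing.
    def response_fn(message, history, *extra_inputs):
        partial = ""
        for token in ["Hel", "lo", "!"]:       # stand-in for real token streaming
            partial += token
            yield partial                      # each yield re-renders the last chat turn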
@@ -1036,25 +591,6 @@ class CustomTabbedInterface(gr.Blocks):
             interface.render()
 
 
-
-# def vllm_abort(self: Any):
-#     sh = self.llm_engine.scheduler
-#     for g in (sh.waiting + sh.running + sh.swapped):
-#         sh.abort_seq_group(g.request_id)
-
-#     from vllm.sequence import SequenceStatus
-#     scheduler = self.llm_engine.scheduler
-#     for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
-#         for seq_group in state_queue:
-#             # if seq_group.request_id == request_id:
-#             # Remove the sequence group from the state queue.
-#             state_queue.remove(seq_group)
-#             for seq in seq_group.seqs:
-#                 if seq.is_finished():
-#                     continue
-#                 scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
-
-
 def vllm_abort(self):
     sh = self.llm_engine.scheduler
     for g in (sh.waiting + sh.running + sh.swapped):
@@ -1231,6 +767,14 @@ def chatml_format(message, history=None, system_prompt=None):
     return chatml_chat_convo_format(conversations, True, default_system=system_prompt)
 
 
+def debug_chat_response_stream_multiturn(*args, **kwargs):
+    message = "This is a debugging message"
+    for i in range(len(message)):
+        time.sleep(0.05)
+        yield message[:i]
+
+
+
 def chat_response_stream_multiturn(
     message: str,
     history: List[Tuple[str, str]],
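`debug_chat_response_stream_multiturn` lets the Space exercise the streaming UI without a GPU or a vLLM engine; it just yields growing prefixes of a fixed string. A quick way to sanity-check it outside the UI (illustrative):

    for chunk in debug_chat_response_stream_multiturn():
        print(chunk)   # prints successively longer prefixes of the debug string, one every 50 ms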
@@ -1242,6 +786,9 @@ def chat_response_stream_multiturn(
     system_prompt: Optional[str] = SYSTEM_PROMPT_1
 ) -> str:
     global LOG_FILE, LOG_PATH
+    if DEBUG:
+        yield from debug_chat_response_stream_multiturn()
+        return
     from vllm import LLM, SamplingParams
     """Build multi turn
     <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
@@ -1274,16 +821,12 @@ def chat_response_stream_multiturn(
 
     message_safety = safety_check(message, history=history)
     if message_safety is not None:
-        yield message_safety
-
+        # yield message_safety
+        raise gr.Error(message_safety)
 
     # history will be appended with message later on
 
-    # full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
-    #     message, history, sys_prompt=system_prompt
-    # )
     full_prompt = chatml_format(message.strip(), history=history, system_prompt=system_prompt)
-    # print(full_prompt)
 
     if len(tokenizer.encode(full_prompt, add_special_tokens=False)) >= 4050:
         raise gr.Error(f"Conversation or prompt is too long, please clear the chatbox or try shorter input.")
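The safety check now raises `gr.Error(message_safety)` instead of yielding the refusal text into the chat; gradio surfaces the exception as an error popup, and with the patched `_stream_fn` above the pending user turn is rolled back. The same guard pattern as a minimal sketch:

    import gradio as gr

    def guarded_handler(message, history):
        problem = safety_check(message, history=history)   # safety_check is defined elsewhere in app.py
        if problem is not None:
            raise gr.Error(problem)   # shown to the user as an error modal; no chat turn is appended
        ...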
@@ -1334,6 +877,89 @@ def chat_response_stream_multiturn(
     if message_safety is not None:
         yield message_safety
         return
+
+
+
+def debug_generate_free_form_stream(message):
+    output = " This is a debugging message...."
+    for i in range(len(output)):
+        time.sleep(0.05)
+        yield message + output[:i]
+
+
+def generate_free_form_stream(
+    message: str,
+    temperature: float,
+    max_tokens: int,
+    frequency_penalty: float,
+    presence_penalty: float,
+    current_time: Optional[float] = None,
+    stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
+) -> str:
+    global LOG_FILE, LOG_PATH
+    if DEBUG:
+        yield from debug_generate_free_form_stream(message)
+        return
+    from vllm import LLM, SamplingParams
+    """Build multi turn
+    """
+    global llm, RES_PRINTED
+    assert llm is not None
+    tokenizer = llm.get_tokenizer()
+    # force removing all
+    vllm_abort(llm)
+
+    temperature = float(temperature)
+    frequency_penalty = float(frequency_penalty)
+    max_tokens = int(max_tokens)
+
+    stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
+    stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>']))
+
+    sampling_params = SamplingParams(
+        temperature=temperature,
+        max_tokens=max_tokens,
+        frequency_penalty=frequency_penalty,
+        presence_penalty=presence_penalty,
+        stop=stop_strings,
+        # ignore_eos=True,
+    )
+
+    # full_prompt = message
+    if len(message) == 0:
+        raise gr.Error("The message cannot be empty!")
+
+    message_safety = safety_check(message)
+    if message_safety is not None:
+        raise gr.Error(message_safety)
+
+    if len(tokenizer.encode(message, add_special_tokens=False)) >= 4050:
+        raise gr.Error(f"Prompt is too long!")
+
+    cur_out = None
+    for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
+        if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
+            # optionally check safety, and respond
+            if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
+                message_safety = safety_check(cur_out, history=None)
+                if message_safety is not None:
+                    raise gr.Error(message_safety)
+            yield message + cur_out
+        assert len(gen) == 1, f'{gen}'
+        item = next(iter(gen.values()))
+        cur_out = item.outputs[0].text
+        #cur_out = "Our system is under maintenance, will be back soon!"
+        if j >= max_tokens - 2:
+            gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
+
+    if cur_out is not None:
+        yield message + cur_out
+
+    message_safety = safety_check(message + cur_out, history=None)
+    if message_safety is not None:
+        raise gr.Error(message_safety)
+
+
 
 
 def maybe_log_conv_file(current_time, history, message, response, **kwargs):
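`generate_free_form_stream` drives the new Free-form tab: it builds a `SamplingParams` from the UI values, always adds `</s>` and `<|im_start|>` to the user-supplied stop strings, and streams partial completions through `vllm_generate_stream` (defined elsewhere in app.py). An illustrative direct call, assuming the global `llm` engine has already been initialized by `launch_demo` (or `DEBUG` is set):

    for text in generate_free_form_stream(
        "Translate to French: Hello world.",
        temperature=0.2,
        max_tokens=64,
        frequency_penalty=0.5,
        presence_penalty=0.0,
        stop_strings="</s>,<|im_start|>",
    ):
        print(text)   # the prompt followed by a growing completion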
@@ -1715,6 +1341,48 @@ CHAT_EXAMPLES = [
 
 # performance items
 
+def create_free_form_generation_demo():
+    global short_model_path
+    max_tokens = MAX_TOKENS
+    temperature = TEMPERATURE
+    frequence_penalty = FREQUENCE_PENALTY
+    presence_penalty = PRESENCE_PENALTY
+
+    introduction = """
+## Free-form:
+Put any context string (like few-shot prompts) and get the model to generate.
+"""
+
+    with gr.Blocks() as demo_free_form:
+        gr.Markdown(introduction)
+
+        with gr.Row():
+            txt = gr.Textbox(
+                scale=4,
+                lines=16,
+                show_label=False,
+                placeholder="Enter any free form text and submit",
+                container=False,
+            )
+        with gr.Row():
+            free_submit_button = gr.Button('Submit')
+        with gr.Row():
+            temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
+            length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
+            freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
+            pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
+            stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
+
+        free_submit_button.click(
+            generate_free_form_stream,
+            [txt, temp, length, freq_pen, pres_pen, stop_strings],
+            txt
+        )
+    return demo_free_form
+
+
+
+
 
 def launch_demo():
     global demo, llm, DEBUG, LOG_FILE
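`create_free_form_generation_demo` wires the textbox, sampling controls, and Submit button into a `gr.Blocks` whose click handler streams `generate_free_form_stream` back into the same textbox. Stand-alone usage would look roughly like this (illustrative; in this commit the Blocks is instead mounted as a third tab of the `CustomTabbedInterface`):

    demo_free_form = create_free_form_generation_demo()
    demo_free_form.queue()   # queuing is required for the generator-based click handler to stream
    demo_free_form.launch()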
@@ -1817,7 +1485,7 @@ def launch_demo():
             gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
             gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
             gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
-            gr.Textbox(value="
+            gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
             gr.Number(value=0, label='current_time', visible=False),
         ],
         outputs=[
@@ -1829,11 +1497,13 @@ def launch_demo():
         description=FILE_UPLOAD_DESCRIPTION,
         allow_flagging=False,
         examples=[
-            ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "
-            ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "
+            ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "<s>,</s>,<|im_start|>"],
+            ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "<s>,</s>,<|im_start|>,\\n"]
         ],
         cache_examples=False,
     )
+
+    demo_free_form = create_free_form_generation_demo()
 
     demo_chat = gr.ChatInterface(
         response_fn,
@@ -1869,8 +1539,8 @@ def launch_demo():
     descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
 
     demo = CustomTabbedInterface(
-        interface_list=[demo_chat, demo_file_upload],
-        tab_names=["Chat Interface", "Batch Inference"],
+        interface_list=[demo_chat, demo_file_upload, demo_free_form],
+        tab_names=["Chat Interface", "Batch Inference", "Free-form"],
         title=f"{model_title}",
         description=descriptions,
     )
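With the Free-form tab added, the tabbed demo is launched by the remainder of `launch_demo()` exactly as before. A typical tail for such a Space, shown only as an illustration (the real queue/launch arguments live elsewhere in app.py and are untouched by this commit):

    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)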