Update app.py
app.py CHANGED

@@ -25,7 +25,7 @@ from tqdm.auto import tqdm
from huggingface_hub import snapshot_download


-# @@
+# @@ environments ================

DEBUG = bool(int(os.environ.get("DEBUG", "1")))
BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
@@ -34,59 +34,53 @@ DTYPE = os.environ.get("DTYPE", "bfloat16")

# ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
+LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0")))
+
# ! uploaded model path, will be downloaded to MODEL_PATH
HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
+# ! if model is private, need HF_TOKEN to access the model
HF_TOKEN = os.environ.get("HF_TOKEN", None)
+# ! path where the model is downloaded, either on ./ or persistent disc
MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")

-
+# ! list of keywords to disabled as security measures to comply with local regulation
+KEYWORDS = os.environ.get("KEYWORDS", "").strip()
+KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
+KEYWORDS = [x.lower() for x in KEYWORDS]

# gradio config
PORT = int(os.environ.get("PORT", "7860"))
+# how many iterations to yield response
STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
+# how many iterations to perform safety check on response
+STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
+
+# self explanatory
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4"))
+gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))

+# whether to enable quantization, currently not in use
+QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))

-"""
-TODO:
-need to upload the model as hugginface/models/seal_13b_a
-# https://huggingface.co/docs/hub/spaces-overview#managing-secrets
-set
-HF_TOKEN=???

-
-
+"""
+Internal instructions of how to configure the DEMO

+1. Upload SFT model as a model to huggingface: hugginface/models/seal_13b_a
+2. If the model weights is private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings
+3. space config env: `HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a` or the underlining model
+4. If enable persistent storage: set
HF_HOME=/data/.huggingface
MODEL_PATH=/data/.huggingface/seal-13b-chat-a
-
-# if not persistent
+if not:
MODEL_PATH=./seal-13b-chat-a
-HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
-
-
-===== Application Startup at 2023-10-20 04:03:49 =====
-
-DEBUG mode: False
-Torch version: 2.1.0+cu121
-Torch CUDA version: 12.1
-/home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/cuda/__init__.py:138: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
-return torch._C._cuda_getDeviceCount() > 0
-Unable to obtain compute_capability: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver.
-Launch config: model_title='SeaL-13B - An Assistant for South East Asian Languages' / tensor_parallel=1 / dtype='bfloat16' / 2048 | BLOCK_ZH=True
-| STREAM_YIELD_MULTIPLE=1
-| frequence_penalty=0.4
-| temperature=0.1
-| hf_model_name=DAMO-NLP-SG/seal-13b-chat-a
-| model_path=./seal-13b-chat-a
-| DOWNLOAD_SNAPSHOT=True
-sys=You are a multilingual, helpful,

"""


+
# ==============================
print(f'DEBUG mode: {DEBUG}')
print(f'Torch version: {torch.__version__}')
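The new `KEYWORDS` setting arrives as one semicolon-separated environment string and is normalised into a lowercase list, so an unset or empty variable leaves keyword filtering effectively disabled. A minimal standalone sketch of that parsing (the example value is hypothetical, not taken from the Space configuration):

```python
import os

# Hypothetical value for illustration only.
os.environ["KEYWORDS"] = "foo;Bar Baz"

keywords = os.environ.get("KEYWORDS", "").strip()
keywords = keywords.split(";") if len(keywords) > 0 else []
keywords = [x.lower() for x in keywords]

print(keywords)  # ['foo', 'bar baz']; an unset variable would give []
```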
@@ -95,16 +89,109 @@ try:
except Exception as e:
    print(f'Failed to print cuda version: {e}')

-
+try:
+    compute_capability = torch.cuda.get_device_capability()
+    print(f'Torch CUDA compute_capability: {compute_capability}')
+except Exception as e:
+    print(f'Failed to print compute_capability version: {e}')


# @@ constants ================

+DTYPES = {
+    'float16': torch.float16,
+    'bfloat16': torch.bfloat16
+}
+
+llm = None
+demo = None
+
+
+BOS_TOKEN = '<s>'
+EOS_TOKEN = '</s>'
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information.
+
+As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
+Your response should adapt to the norms and customs of the respective language and culture.
+"""
+
+# ============ CONSTANT ============
+# https://github.com/gradio-app/gradio/issues/884
+MODEL_NAME = "SeaLLM-13B"
+MODEL_TITLE = "SeaLLM-13B - An Assistant for South East Asian Languages"
+# ! add icon: "<img src='file/lion.jpg' alt='image One'>"
+MODEL_TITLE = """
+<div class="container" style="
+    align-items: center;
+    justify-content: center;
+    display: flex;
+">
+<div class="image" >
+<img src="file/seal_logo.png" style="
+    max-width: 10em;
+    max-height: 5%;
+    height: 5em;
+    width: 5em;
+    float: left;
+    margin-left: auto;
+">
+</div>
+<div class="text" style="
+    padding-left: 20px;
+    padding-top: 2%;
+    float: left;
+">
+<h1>SeaLLM-13B - An Assistant for South East Asian Languages</h1>
+</div>
+</div>
+"""
+MODEL_DESC = """
+<span style="font-size: larger">
+This is SeaLLM-13B - a chatbot assistant optimized for South East Asian Languages. It can produce helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
+</span>
+<br>
+<span style="color: red">NOTICE: The chatbot may produce inaccurate and harmful information about people, places, or facts. \
+We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
+or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
+""".strip()
+
+
+cite_markdown = """
+## Citation
+If you find our project useful, hope you can star our repo and cite our paper as follows:
+```
+@article{damonlpsg2023seallm,
+  author = {???},
+  title = {SeaLLM: A language model for South East Asian Languages},
+  year = 2023,
+}
+```
+"""
+
+# warning_markdown = """
+# ## Warning:
+# <span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
+# <span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
+# or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
+# """
+
+path_markdown = """
+#### Model path:
+{model_path}
+"""


def _detect_lang(text):
    from langdetect import detect as detect_lang
-    from langdetect.detector import LangDetectException
    dlang = None
    try:
        dlang = detect_lang(text)
@@ -118,11 +205,12 @@ def _detect_lang(text):
    return dlang


-def hf_model_weights_iterator(
+def custom_hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
+    # ! if use vllm==0.1.4, use this to augment hf_model_weights_iterator loader
    from vllm.model_executor.weight_utils import Disabledtqdm
    # Prepare file lock directory to prevent multiple processes from
    # downloading the same model weights at the same time.
@@ -143,7 +231,6 @@ def hf_model_weights_iterator(
    hf_folder = model_name_or_path

    hf_bin_files = [
-        # x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
        x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
        if not x.endswith("training_args.bin")
    ]
@@ -236,9 +323,9 @@ def llama_load_weights(
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
    load_format: str = "auto",
-    # load_format: str = "pt",
    revision: Optional[str] = None
):
+    # if use vllm==0.1.4
    from vllm.model_executor.weight_utils import (
        load_tensor_parallel_weights
    )
@@ -261,7 +348,7 @@ def llama_load_weights(
    state_dict = self.state_dict()
    need_to_load = len(state_dict)
    loaded = 0
-    iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
+    iterator = custom_hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)

    for name, loaded_weight in iterator:
        if "rotary_emb.inv_freq" in name:
@@ -331,7 +418,6 @@ def llama_load_weights(
                    loaded_weight[v_offsets[0]:v_offsets[1]],
                ], 0
            )
-            # print(f'{name} | {q_offsets} | {k_offsets} | {v_offsets}')
            assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
            param.data.copy_(_loaded_weight)
            loaded += 1.0
@@ -398,19 +484,158 @@ def llama_load_weights(
    print(f'Loaded all {loaded} params loaded out of {need_to_load}')


+def new_llama_load_weights(
+    self,
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    load_format: str = "auto",
+    revision: Optional[str] = None
+):
+    # If use newest vllm
+    from vllm.model_executor.weight_utils import (
+        load_tensor_parallel_weights, hf_model_weights_iterator
+    )
+    from vllm.model_executor.parallel_utils.parallel_state import (
+        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+
+    if self.quant_config is None:
+        weight_suffixes = ["weight"]
+    else:
+        weight_suffixes = self.quant_config.get_tp_tensor_names()
+
+    column_parallel_weights: List[str] = []
+    for layer in self._column_parallel_layers:
+        for suffix in weight_suffixes:
+            column_parallel_weights.append(f"{layer}.{suffix}")
+    row_parallel_weights: List[str] = []
+    for layer in self._row_parallel_layers:
+        for suffix in weight_suffixes:
+            row_parallel_weights.append(f"{layer}.{suffix}")
+
+    tp_size = get_tensor_model_parallel_world_size()
+    tp_rank = get_tensor_model_parallel_rank()
+    assert tp_size == 1, f'tensorparallel >=2 not allowed. {tp_size}'
+    q_proj_shard_size = (self.config.hidden_size // tp_size)
+    num_kv_heads_replicas = max(1,
+                                tp_size // self.config.num_key_value_heads)
+    num_kv_heads_per_gpu = max(1,
+                               self.config.num_key_value_heads // tp_size)
+    kv_proj_shard_size = (self.config.hidden_size //
+                          self.config.num_attention_heads *
+                          num_kv_heads_per_gpu)
+    attention_weight_specs = [
+        # (weight_name, shard_size, offset)
+        ("q_proj", q_proj_shard_size, 0),
+        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
+        ("v_proj", kv_proj_shard_size,
+         q_proj_shard_size + kv_proj_shard_size),
+    ]
+    state_dict = self.state_dict()
+    need_to_load = len(state_dict)
+    loaded = 0
+
+    for name, loaded_weight in hf_model_weights_iterator(
+            model_name_or_path, cache_dir, load_format, revision):
+        if "rotary_emb.inv_freq" in name:
+            continue
+
+        is_packed = False
+        is_transposed = False
+        if self.quant_config is not None:
+            is_packed = self.quant_config.is_packed(name)
+            is_transposed = self.quant_config.is_transposed(name)
+        if is_transposed:
+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+            loaded_weight = loaded_weight.T
+
+        is_attention_weight = False
+        for weight_name, shard_size, offset in attention_weight_specs:
+            if weight_name not in name or "qkv_proj" in name:
+                continue
+            param = state_dict[name.replace(weight_name, "qkv_proj")]
+            if is_transposed:
+                param = param.T
+
+            if is_packed:
+                shard_size //= self.quant_config.pack_factor
+                offset //= self.quant_config.pack_factor
+
+            if weight_name in ["k_proj", "v_proj"]:
+                shard_id = tp_rank // num_kv_heads_replicas
+            else:
+                shard_id = tp_rank
+            loaded_weight = loaded_weight[shard_size *
+                                          shard_id:shard_size *
+                                          (shard_id + 1)]
+            param_slice = param.data[offset:offset + shard_size]
+            assert param_slice.shape == loaded_weight.shape
+
+            param_slice.copy_(loaded_weight)
+            loaded += 1.0 / 3
+            is_attention_weight = True
+            break
+        if is_attention_weight:
+            continue
+
+        # TODO: need to figure out to do sharding with qkv_proj fused
+
+        is_gate_up_weight = False
+        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+            if weight_name not in name or "gate_up_proj" in name:
+                continue
+            param = state_dict[name.replace(weight_name, "gate_up_proj")]
+            if is_transposed:
+                param = param.T
+
+            shard_size = param.shape[0] // 2
+            loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
+                                          (tp_rank + 1)]
+            param_slice = param.data[shard_size * stride_id:shard_size *
+                                     (stride_id + 1)]
+            assert param_slice.shape == loaded_weight.shape
+            param_slice.copy_(loaded_weight)
+            loaded += 1.0 / 2
+            is_gate_up_weight = True
+            break
+        if is_gate_up_weight:
+            continue
+
+        # TODO: need to figure out to do sharding with gate_up_proj fused
+
+        param = state_dict[name]
+        if is_transposed:
+            param = param.T
+
+        if "embed_tokens" in name or "lm_head" in name:
+            load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                              tp_rank)
+            loaded += 1
+            continue
+
+        load_tensor_parallel_weights(param, loaded_weight, name,
+                                     column_parallel_weights,
+                                     row_parallel_weights, tp_rank)
+        loaded += 1
+
+    if np.abs(loaded - need_to_load) < 0.01:
+        print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
+    else:
+        print(f'Loaded all {loaded} params loaded out of {need_to_load}')
+
+
# Reassign LlamaForCausalLM.load_weights with llama_load_weights
if not DEBUG:

-    # vllm import
-    # from vllm import LLM, SamplingParams
-    # ! reconfigure vllm to faster llama
    try:
        import vllm
        from vllm.model_executor.model_loader import _MODEL_REGISTRY
        from vllm.model_executor.models import LlamaForCausalLM

        _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
-        LlamaForCausalLM.load_weights = llama_load_weights
+        if vllm.__version__ == "0.1.4":
+            LlamaForCausalLM.load_weights = llama_load_weights
+        else:
+            LlamaForCausalLM.load_weights = new_llama_load_weights

    if DTYPE == "bfloat16":
        try:
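The fused `qkv_proj` copy in `new_llama_load_weights` depends on per-projection shard sizes and offsets derived from the model config. As a worked example, assuming a standard Llama-2-13B shape (hidden_size 5120, 40 attention heads, 40 KV heads) and `tp_size == 1` as the assert above requires, the arithmetic reduces to the sketch below (the numbers are illustrative, not read from any checkpoint):

```python
# Worked example of the shard-size arithmetic, assuming Llama-2-13B-like dimensions.
hidden_size = 5120
num_attention_heads = 40
num_key_value_heads = 40
tp_size = 1

q_proj_shard_size = hidden_size // tp_size                     # 5120
num_kv_heads_per_gpu = max(1, num_key_value_heads // tp_size)  # 40
kv_proj_shard_size = (hidden_size // num_attention_heads
                      * num_kv_heads_per_gpu)                  # 5120

attention_weight_specs = [
    # (weight_name, shard_size, row offset into the fused qkv_proj parameter)
    ("q_proj", q_proj_shard_size, 0),                            # rows 0..5119
    ("k_proj", kv_proj_shard_size, q_proj_shard_size),           # rows 5120..10239
    ("v_proj", kv_proj_shard_size,
     q_proj_shard_size + kv_proj_shard_size),                    # rows 10240..15359
]
print(attention_weight_specs)
```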
@@ -433,33 +658,6 @@ if not DEBUG:
set_documentation_group("component")


-
-DTYPES = {
-    'float16': torch.float16,
-    'bfloat16': torch.bfloat16
-}
-
-llm = None
-demo = None
-
-
-BOS_TOKEN = '<s>'
-EOS_TOKEN = '</s>'
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
-SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
-answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
-that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
-correct. If you don't know the answer to a question, please don't share false information.
-
-As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
-Your response should adapt to the norms and customs of the respective language and culture.
-"""
-
RES_PRINTED = False

def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
@@ -576,8 +774,117 @@ def _setup_stop_events(
                api_name=False,
                queue=False,
            )
+    # upon clear, cancel the submit event as well
+    if self.clear_btn:
+        self.clear_btn.click(
+            lambda: ([], [], None, Button.update(interactive=True)),
+            None,
+            [self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
+            queue=False,
+            api_name=False,
+            cancels=event_to_cancel,
+        )

+# TODO: reconfigure clear button as stop and clear button
+def _setup_events(self) -> None:
+    has_on = False
+    try:
+        from gradio.events import Dependency, EventListenerMethod, on
+        has_on = True
+    except ImportError as ie:
+        has_on = False
+    submit_fn = self._stream_fn if self.is_generator else self._submit_fn
+
+
+    if has_on:
+        # new version
+        submit_triggers = (
+            [self.textbox.submit, self.submit_btn.click]
+            if self.submit_btn
+            else [self.textbox.submit]
+        )
+        submit_event = (
+            on(
+                submit_triggers,
+                self._clear_and_save_textbox,
+                [self.textbox],
+                [self.textbox, self.saved_input],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                self._display_input,
+                [self.saved_input, self.chatbot_state],
+                [self.chatbot, self.chatbot_state],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                submit_fn,
+                [self.saved_input, self.chatbot_state] + self.additional_inputs,
+                [self.chatbot, self.chatbot_state],
+                api_name=False,
+            )
+        )
+        self._setup_stop_events(submit_triggers, submit_event)
+    else:
+        raise ValueError(f'Better install new gradio version than 3.44.0')
+
+    if self.retry_btn:
+        retry_event = (
+            self.retry_btn.click(
+                self._delete_prev_fn,
+                [self.chatbot_state],
+                [self.chatbot, self.saved_input, self.chatbot_state],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                self._display_input,
+                [self.saved_input, self.chatbot_state],
+                [self.chatbot, self.chatbot_state],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                submit_fn,
+                [self.saved_input, self.chatbot_state] + self.additional_inputs,
+                [self.chatbot, self.chatbot_state],
+                api_name=False,
+            )
+        )
+        self._setup_stop_events([self.retry_btn.click], retry_event)
+
+    if self.undo_btn:
+        self.undo_btn.click(
+            self._delete_prev_fn,
+            [self.chatbot_state],
+            [self.chatbot, self.saved_input, self.chatbot_state],
+            api_name=False,
+            queue=False,
+        ).then(
+            lambda x: x,
+            [self.saved_input],
+            [self.textbox],
+            api_name=False,
+            queue=False,
+        )

+    # Reconfigure clear_btn to stop and clear text box
+    # if self.clear_btn:
+    #     self.clear_btn.click(
+    #         lambda: ([], [], None),
+    #         None,
+    #         [self.chatbot, self.chatbot_state, self.saved_input],
+    #         queue=False,
+    #         api_name=False,
+    #         cancels=submit_event,
+    #     )
+
+
+# replace
gr.ChatInterface._setup_stop_events = _setup_stop_events
+gr.ChatInterface._setup_events = _setup_events

def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
    global llm
@@ -611,7 +918,6 @@ def vllm_abort(self: Any):
                    continue
                scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)

-# def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]:
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
    from vllm.outputs import RequestOutput
    # Initialize tqdm.
@@ -624,16 +930,9 @@ def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
        step_outputs = self.llm_engine.step()
        for output in step_outputs:
            outputs[output.request_id] = output
-        # outputs = sorted(outputs, key=lambda x: int(x.request_id))
        if len(outputs) > 0:
            yield outputs
-
-    # pbar.close()
-    # Sort the outputs by request ID.
-    # This is necessary because some requests may be finished earlier than
-    # its previous requests.
-    # outputs = sorted(outputs, key=lambda x: int(x.request_id))
-    # return outputs
+


def vllm_generate_stream(
@@ -692,64 +991,47 @@ def vllm_generate_stream(
    yield from _vllm_run_engine(self, use_tqdm)


-# def chat_response_stream(
-#     message: str,
-#     history: List[Tuple[str, str]],
-#     temperature: float,
-#     max_tokens: int,
-#     frequency_penalty: float,
-#     system_prompt: str
-# ) -> str:
-#     global llm, RES_PRINTED
-#     assert llm is not None
-#     # force removing all
-#     vllm_abort(llm)
-
-#     temperature = float(temperature)
-#     frequency_penalty = float(frequency_penalty)
-#     max_tokens = int(max_tokens)
-#     if system_prompt.strip() != '':
-#         # chat version, add system prompt
-#         message = llama_chat_sys_input_seq_constructor(
-#             message.strip(),
-#             sys_prompt=system_prompt
-#         )
-#     sampling_params = SamplingParams(
-#         temperature=temperature, max_tokens=max_tokens,
-#         frequency_penalty=frequency_penalty,
-#     )
-#     cur_out = None
-#     for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
-#         if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
-#             yield cur_out
-#         assert len(gen) == 1, f'{gen}'
-#         item = next(iter(gen.values()))
-#         cur_out = item.outputs[0].text
-#         if not RES_PRINTED:
-#             print(f'{message}<<<{cur_out}>>>')
-#             RES_PRINTED = True
-#     if cur_out is not None:
-#         yield cur_out
-
-
BLOCK_MESSAGE = """Sorry, Chinese is not currently supported. Please clear the chat box for a new conversation.
抱歉,目前不支持中文。 请清除聊天框以进行新对话。"""

+KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated questions, I'll be glad to help."
+
def block_zh(
    message: str,
    history: List[Tuple[str, str]]
) -> str:
-
-    if any((BLOCK_MESSAGE in x[1].strip()) for x in history):
+    if history is not None and any((BLOCK_MESSAGE in x[1].strip()) for x in history):
        return True
    elif 'zh' in _detect_lang(message):
        print(f'Detect zh: {message}')
        return True
-    # ! optionally detect every responses message
    else:
        return False

-
+
+def log_responses(history, message, response):
+    pass
+
+
+def safety_check(text, history=None, ) -> Optional[str]:
+    """
+    Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
+    This provides an additional security measure to enhance safety and compliance with local regulations.
+    """
+    if BLOCK_ZH:
+        if history is not None:
+            if block_zh(text, history):
+                return BLOCK_MESSAGE
+        else:
+            if "zh" in _detect_lang(text):
+                return BLOCK_MESSAGE
+
+    if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
+        return KEYWORD_BLOCK_MESSAGE
+
+    return None
+
+
def chat_response_stream_multiturn(
    message: str,
    history: List[Tuple[str, str]],
|
|
| 779 |
|
| 780 |
message = message.strip()
|
| 781 |
|
| 782 |
-
|
| 783 |
-
|
|
|
|
|
|
|
| 784 |
|
| 785 |
-
# ! lang detect
|
| 786 |
-
if BLOCK_ZH:
|
| 787 |
-
if block_zh(message, history):
|
| 788 |
-
yield BLOCK_MESSAGE
|
| 789 |
-
return
|
| 790 |
-
|
| 791 |
-
# history.append([message, None])
|
| 792 |
# history will be appended with message later on
|
| 793 |
full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
|
| 794 |
message, history, sys_prompt=system_prompt
|
| 795 |
)
|
| 796 |
-
|
| 797 |
sampling_params = SamplingParams(
|
| 798 |
temperature=temperature, max_tokens=max_tokens,
|
| 799 |
frequency_penalty=frequency_penalty,
|
| 800 |
)
|
| 801 |
cur_out = None
|
| 802 |
-
|
| 803 |
for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
|
| 804 |
if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 805 |
yield cur_out
|
| 806 |
assert len(gen) == 1, f'{gen}'
|
| 807 |
item = next(iter(gen.values()))
|
| 808 |
cur_out = item.outputs[0].text
|
| 809 |
|
| 810 |
-
|
| 811 |
-
print(f'{full_prompt}<<<{cur_out}>>>\n')
|
| 812 |
-
# RES_PRINTED = True
|
| 813 |
if cur_out is not None:
|
| 814 |
yield cur_out
|
| 815 |
|
| 816 |
-
|
| 817 |
-
if
|
| 818 |
-
|
| 819 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 820 |
|
| 821 |
|
| 822 |
def debug_chat_response_echo(
|
|
@@ -832,44 +1118,6 @@ def debug_chat_response_echo(
|
|
| 832 |
yield f"repeat: {message}"
|
| 833 |
|
| 834 |
|
| 835 |
-
# ============ CONSTANT ============
|
| 836 |
-
# https://github.com/gradio-app/gradio/issues/884
|
| 837 |
-
MODEL_NAME = "SeaL-13B"
|
| 838 |
-
MODEL_TITLE = "SeaL-13B - An Assistant for South East Asian Languages"
|
| 839 |
-
# ! add icon: "<img src='file/lion.jpg' alt='image One'>"
|
| 840 |
-
MODEL_DESC = """
|
| 841 |
-
<span style="font-size: larger">
|
| 842 |
-
This is a DAMO SeaL-13B chatbot assistant built by DAMO Academy, Alibaba Group. It can produce helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
|
| 843 |
-
</span>
|
| 844 |
-
""".strip()
|
| 845 |
-
# <br>
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
cite_markdown = """
|
| 849 |
-
## Citation
|
| 850 |
-
If you find our project useful, hope you can star our repo and cite our paper as follows:
|
| 851 |
-
```
|
| 852 |
-
@article{damonlpsg2023seallm,
|
| 853 |
-
author = {???},
|
| 854 |
-
title = {SeaL: A language model for South East Asian Languages},
|
| 855 |
-
year = 2023,
|
| 856 |
-
}
|
| 857 |
-
```
|
| 858 |
-
"""
|
| 859 |
-
|
| 860 |
-
warning_markdown = """
|
| 861 |
-
## Warning:
|
| 862 |
-
<span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
|
| 863 |
-
<span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
|
| 864 |
-
or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
|
| 865 |
-
"""
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
path_markdown = """
|
| 869 |
-
#### Model path:
|
| 870 |
-
{model_path}
|
| 871 |
-
"""
|
| 872 |
-
|
| 873 |
def check_model_path(model_path) -> str:
|
| 874 |
assert os.path.exists(model_path), f'{model_path} not found'
|
| 875 |
ckpt_info = "None"
|
|
@@ -903,11 +1151,14 @@ def launch():
|
|
| 903 |
print(
|
| 904 |
f'Launch config: {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
|
| 905 |
f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
|
|
|
|
| 906 |
f'\n| frequence_penalty={frequence_penalty} '
|
| 907 |
f'\n| temperature={temperature} '
|
| 908 |
f'\n| hf_model_name={hf_model_name} '
|
| 909 |
f'\n| model_path={model_path} '
|
| 910 |
f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
|
|
|
|
|
|
|
| 911 |
f'\nsys={SYSTEM_PROMPT_1}'
|
| 912 |
f'\ndesc={model_desc}'
|
| 913 |
)
|
|
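In the streaming loop of `chat_response_stream_multiturn` above, `STREAM_YIELD_MULTIPLE` throttles how often partial output is yielded and `STREAM_CHECK_MULTIPLE` how often that partial output is re-screened by `safety_check`. A small standalone sketch of the cadence (the values are hypothetical; the real loop wraps `vllm_generate_stream`):

```python
STREAM_YIELD_MULTIPLE = 1   # yield on every iteration
STREAM_CHECK_MULTIPLE = 4   # hypothetical: re-screen every 4th iteration

for j in range(1, 11):
    yield_now = STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0
    check_now = yield_now and STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0
    print(f"step {j}: yield={yield_now} safety_check={check_now}")
# steps 4 and 8 are both yielded and re-screened; a safety hit there ends the stream early.
```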
@@ -928,13 +1179,23 @@ def launch():
        snapshot_download(hf_model_name, local_dir=model_path)

    import vllm
-    from vllm import LLM
+    from vllm import LLM

    print(F'VLLM: {vllm.__version__}')
    ckpt_info = check_model_path(model_path)

    print(f'Load path: {model_path} | {ckpt_info}')
-
+
+    if QUANTIZATION == 'awq':
+        print(F'Load model in int4 quantization')
+        llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, quantization="awq")
+    else:
+        llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization)
+
+    try:
+        print(llm.llm_engine.workers[0].model)
+    except Exception as e:
+        print(f'Cannot print model worker: {e}')

    print(f'Use system prompt:\n{sys_prompt}')

@@ -957,16 +1218,17 @@ def launch():
        stop_btn=None,
        title=f"{model_title}",
        description=f"{model_desc}",
-        # ! decide if can change the system prompt.
        additional_inputs=[
            gr.Number(value=temperature, label='Temperature (higher -> more random)'),
            gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
            gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
+            # ! Remove the system prompt textbox to avoid jailbreaking
            # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
        ],
    )
+    demo.title = MODEL_NAME
    with demo:
-        gr.Markdown(warning_markdown)
+        # gr.Markdown(warning_markdown)
        gr.Markdown(cite_markdown)
        gr.Markdown(path_markdown.format(model_path=model_path))

@@ -981,30 +1243,3 @@ def main():

if __name__ == "__main__":
    main()
-
-
-"""
-
-export CUDA_VISIBLE_DEVICES=0
-export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000
-export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster
-export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp
-
-export DEBUG=0
-export CUDA_VISIBLE_DEVICES=0
-export MODEL_PATH=seal_13b_a
-export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW12k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.SeaV2Cq13M.SeaV2Cq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_6000
-
-export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/mer13s108Hi16kPretFlCWNLP12k_SFT2.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.Sft2Censor.Sft2Censor.m4k.b8.lr1e5.linear.wa0k.ms1144k.grac1.se1.6g.v4c.zfsdp/step_4000
-# 70-30 model
-export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/mer13s108Hi16kPretFlCWNLP12k_SFT2.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.BgSft2aCensor0a.BgSft2Cens.BgSft2Cens.m4k.b2.lr1e5.linear.wa0k.ms4577k.grac1.se1.6g.v4c73.zfsdp/step_500
-export PORT=8799
-export BLOCK_ZH=1
-export DEBUG=0
-python app.py
-
-
-DEBUG=1 python app.py
-
-
-"""
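The launch path now branches on `QUANTIZATION` when constructing the vLLM engine. A hedged sketch of that decision in isolation from the rest of `launch()` (the path and values are placeholders; `quantization="awq"` is only passed when explicitly requested, matching the diff):

```python
# Sketch of the engine construction added in launch(); values are placeholders.
from vllm import LLM

QUANTIZATION = ""                  # e.g. set QUANTIZATION=awq in the Space settings
model_path = "./seal-13b-chat-a"
dtype = "bfloat16"
tensor_parallel = 1
gpu_memory_utilization = 0.9

kwargs = dict(
    model=model_path,
    dtype=dtype,
    tensor_parallel_size=tensor_parallel,
    gpu_memory_utilization=gpu_memory_utilization,
)
if QUANTIZATION == "awq":
    # AWQ int4 checkpoints need the quantization flag, as in the launch() branch above.
    kwargs["quantization"] = "awq"

llm = LLM(**kwargs)
```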