support Qwen chat model
Browse files
- app_modules/llm_loader.py +39 -10
- requirements.txt +1 -0
- server.py +3 -0
app_modules/llm_loader.py
CHANGED

```diff
@@ -207,6 +207,7 @@ class LLMLoader:
             0.01
             if "gpt4all-j" in MODEL_NAME_OR_PATH
             or "dolly" in MODEL_NAME_OR_PATH
+            or "Qwen" in MODEL_NAME_OR_PATH
             else 0
         )
         use_fast = (
@@ -216,11 +217,29 @@ class LLMLoader:
         )
         padding_side = "left"  # if "dolly" in MODEL_NAME_OR_PATH else None

-        config = AutoConfig.from_pretrained(
-            MODEL_NAME_OR_PATH,
-            trust_remote_code=True,
-            token=token,
+        config = (
+            AutoConfig.from_pretrained(
+                MODEL_NAME_OR_PATH,
+                trust_remote_code=True,
+                token=token,
+                fp32=hf_pipeline_device_type == "cpu",
+                bf16=(
+                    hf_pipeline_device_type != "cpu"
+                    and torch_dtype == torch.bfloat16
+                ),
+                fp16=(
+                    hf_pipeline_device_type != "cpu"
+                    and torch_dtype != torch.bfloat16
+                ),
+            )
+            if "Qwen" in MODEL_NAME_OR_PATH
+            else AutoConfig.from_pretrained(
+                MODEL_NAME_OR_PATH,
+                trust_remote_code=True,
+                token=token,
+            )
         )
+
         # config.attn_config["attn_impl"] = "triton"
         # config.max_seq_len = 4096
         config.init_device = hf_pipeline_device_type
@@ -360,16 +379,26 @@ class LLMLoader:
                     config=config,
                     trust_remote_code=True,
                 )
-                if token is None
-                else AutoModelForCausalLM.from_pretrained(
-                    MODEL_NAME_OR_PATH,
-                    config=config,
-                    trust_remote_code=True,
-                    token=token,
+                if "Qwen" in MODEL_NAME_OR_PATH
+                else (
+                    AutoModelForCausalLM.from_pretrained(
+                        MODEL_NAME_OR_PATH,
+                        config=config,
+                        trust_remote_code=True,
+                    )
+                    if token is None
+                    else AutoModelForCausalLM.from_pretrained(
+                        MODEL_NAME_OR_PATH,
+                        config=config,
+                        trust_remote_code=True,
+                        token=token,
+                    )
                 )
             )
         )
         print(f"Model memory footprint: {model.get_memory_footprint()}")
+        model = model.eval()
+        # print(f"Model memory footprint: {model.get_memory_footprint()}")
     else:
         model = MODEL_NAME_OR_PATH
```
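Taken together, the two hunks give Qwen models their own loading path: Qwen's remote config expresses precision through explicit fp32/bf16/fp16 flags (fp32 on CPU, otherwise bf16 or fp16 depending on torch_dtype), the model is loaded with trust_remote_code, and it is put into eval mode afterwards. The sketch below condenses that path into a standalone script; the checkpoint ID and the device/dtype defaults are illustrative assumptions, not values taken from this repo.

```python
# Minimal sketch of the Qwen loading path introduced above, with hypothetical
# stand-ins for this repo's variables (MODEL_NAME_OR_PATH,
# hf_pipeline_device_type, torch_dtype). Not the repo's actual entry point.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_NAME_OR_PATH = "Qwen/Qwen-7B-Chat"  # assumed checkpoint, for illustration
hf_pipeline_device_type = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Qwen's remote config takes fp32/bf16/fp16 flags instead of a torch_dtype
# argument: fp32 on CPU, otherwise bf16 or fp16 depending on torch_dtype.
config = AutoConfig.from_pretrained(
    MODEL_NAME_OR_PATH,
    trust_remote_code=True,
    fp32=hf_pipeline_device_type == "cpu",
    bf16=hf_pipeline_device_type != "cpu" and torch_dtype == torch.bfloat16,
    fp16=hf_pipeline_device_type != "cpu" and torch_dtype != torch.bfloat16,
)
# Mirrors the loader; init_device is consumed by some remote-code models
# (e.g. MPT) and is harmlessly stored on the config otherwise.
config.init_device = hf_pipeline_device_type

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    config=config,
    trust_remote_code=True,
).eval()  # eval() disables dropout and other train-time behavior
print(f"Model memory footprint: {model.get_memory_footprint()}")
```

The else branch for non-Qwen models is unchanged: a plain AutoConfig/AutoModelForCausalLM load, with the access token passed through only when one is set.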
requirements.txt
CHANGED

```diff
@@ -32,3 +32,4 @@ gevent
 pydantic >= 1.10.11
 pypdf
 python-telegram-bot
+transformers_stream_generator
```
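The new dependency exists because Qwen's trust_remote_code modeling file imports transformers_stream_generator for streamed generation, so the model class cannot even be loaded without it. Below is a hedged sketch of the interface that relies on it, assuming the public Qwen/Qwen-7B-Chat checkpoint; chat() and chat_stream() are methods defined by Qwen's published remote code, not by this repo.

```python
# Sketch only: exercises the Qwen remote-code chat API whose import pulls in
# transformers_stream_generator. Checkpoint and prompts are illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen-7B-Chat"  # assumed public checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()

# Single-turn chat; `history` carries prior turns for follow-up questions.
response, history = model.chat(tokenizer, "What's generative AI?", history=None)
print(response)

# Streaming variant: yields the progressively decoded response. This is the
# code path that breaks at import time if transformers_stream_generator is absent.
for partial in model.chat_stream(tokenizer, "more on finance", history=history):
    pass
print(partial)  # final accumulated response
```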
server.py
CHANGED

```diff
@@ -86,6 +86,9 @@ if __name__ == "__main__":
     chat_start = timer()
     chat_sync("What's generative AI?", chat_id="test_user")
     chat_sync("more on finance", chat_id="test_user")
+    # chat_sync("给我讲一个年轻人奋斗创业最终取得成功的故事。", chat_id="test_user")
+    # chat_sync("给这个故事起一个标题", chat_id="test_user")
+    # chat_sync("Write the game 'snake' in python", chat_id="test_user")
     chat_end = timer()
     total_time = chat_end - chat_start
     print(f"Total time used: {total_time:.3f} s")
```

(The two commented-out Chinese prompts ask for a story about a young person who works hard at a startup and ultimately succeeds, and then for a title for that story; together with the snake-game prompt they are optional multi-turn smoke tests for the Qwen chat path.)