Sync from GitHub 233d3a7
Browse files- .env.example +7 -0
- main.py +23 -1
- requirements.txt +1 -0
.env.example
CHANGED
|
@@ -16,6 +16,13 @@ MAX_VIDEO_FRAMES=16
|
|
| 16 |
# Transformers loading hints
|
| 17 |
DEVICE_MAP=auto
|
| 18 |
TORCH_DTYPE=auto
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
# Persistent SSE session store (SQLite)
|
| 20 |
# Enable to persist streaming chunks per session_id and allow resume after server restarts.
|
| 21 |
# 1=true, 0=false
|
|
|
|
| 16 |
# Transformers loading hints
|
| 17 |
DEVICE_MAP=auto
|
| 18 |
TORCH_DTYPE=auto
|
| 19 |
+
|
| 20 |
+
# Quantization config (BitsAndBytes 4-bit)
|
| 21 |
+
# Enable 4-bit quantization to reduce VRAM usage (~5GB -> ~1.5GB)
|
| 22 |
+
LOAD_IN_4BIT=1
|
| 23 |
+
BNB_4BIT_COMPUTE_DTYPE=float16
|
| 24 |
+
BNB_4BIT_USE_DOUBLE_QUANT=1
|
| 25 |
+
BNB_4BIT_QUANT_TYPE=nf4
|
| 26 |
# Persistent SSE session store (SQLite)
|
| 27 |
# Enable to persist streaming chunks per session_id and allow resume after server restarts.
|
| 28 |
# 1=true, 0=false
|
main.py
CHANGED
|
@@ -79,6 +79,12 @@ MAX_VIDEO_FRAMES = int(os.getenv("MAX_VIDEO_FRAMES", "16"))
|
|
| 79 |
DEVICE_MAP = os.getenv("DEVICE_MAP", "auto")
|
| 80 |
TORCH_DTYPE = os.getenv("TORCH_DTYPE", "auto")
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
# Persistent session store (SQLite)
|
| 83 |
PERSIST_SESSIONS = str(os.getenv("PERSIST_SESSIONS", "0")).lower() in ("1", "true", "yes", "y")
|
| 84 |
SESSIONS_DB_PATH = os.getenv("SESSIONS_DB_PATH", "sessions.db")
|
|
@@ -426,7 +432,7 @@ class CancelResponse(BaseModel):
|
|
| 426 |
class Engine:
|
| 427 |
def __init__(self, model_id: str, hf_token: Optional[str] = None):
|
| 428 |
# Lazy import heavy deps
|
| 429 |
-
from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForVision2Seq, AutoModel
|
| 430 |
# AutoModelForImageTextToText is the v5+ replacement for Vision2Seq in Transformers
|
| 431 |
try:
|
| 432 |
from transformers import AutoModelForImageTextToText # type: ignore
|
|
@@ -442,6 +448,22 @@ class Engine:
|
|
| 442 |
# Only pass 'token' (use_auth_token is deprecated and causes conflicts)
|
| 443 |
model_kwargs["token"] = hf_token
|
| 444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
# Device and dtype resolution
|
| 446 |
try:
|
| 447 |
import torch # local import to avoid heavy import at module load
|
|
|
|
| 79 |
DEVICE_MAP = os.getenv("DEVICE_MAP", "auto")
|
| 80 |
TORCH_DTYPE = os.getenv("TORCH_DTYPE", "auto")
|
| 81 |
|
| 82 |
+
# Quantization config (BitsAndBytes)
|
| 83 |
+
LOAD_IN_4BIT = str(os.getenv("LOAD_IN_4BIT", "1")).lower() in ("1", "true", "yes", "y")
|
| 84 |
+
BNB_4BIT_COMPUTE_DTYPE = os.getenv("BNB_4BIT_COMPUTE_DTYPE", "float16")
|
| 85 |
+
BNB_4BIT_USE_DOUBLE_QUANT = str(os.getenv("BNB_4BIT_USE_DOUBLE_QUANT", "1")).lower() in ("1", "true", "yes", "y")
|
| 86 |
+
BNB_4BIT_QUANT_TYPE = os.getenv("BNB_4BIT_QUANT_TYPE", "nf4")
|
| 87 |
+
|
| 88 |
# Persistent session store (SQLite)
|
| 89 |
PERSIST_SESSIONS = str(os.getenv("PERSIST_SESSIONS", "0")).lower() in ("1", "true", "yes", "y")
|
| 90 |
SESSIONS_DB_PATH = os.getenv("SESSIONS_DB_PATH", "sessions.db")
|
|
|
|
| 432 |
class Engine:
|
| 433 |
def __init__(self, model_id: str, hf_token: Optional[str] = None):
|
| 434 |
# Lazy import heavy deps
|
| 435 |
+
from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForVision2Seq, AutoModel, BitsAndBytesConfig
|
| 436 |
# AutoModelForImageTextToText is the v5+ replacement for Vision2Seq in Transformers
|
| 437 |
try:
|
| 438 |
from transformers import AutoModelForImageTextToText # type: ignore
|
|
|
|
| 448 |
# Only pass 'token' (use_auth_token is deprecated and causes conflicts)
|
| 449 |
model_kwargs["token"] = hf_token
|
| 450 |
|
| 451 |
+
# Add quantization config if enabled
|
| 452 |
+
if LOAD_IN_4BIT:
|
| 453 |
+
try:
|
| 454 |
+
import torch
|
| 455 |
+
compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE, torch.float16)
|
| 456 |
+
quant_config = BitsAndBytesConfig(
|
| 457 |
+
load_in_4bit=True,
|
| 458 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
| 459 |
+
bnb_4bit_use_double_quant=BNB_4BIT_USE_DOUBLE_QUANT,
|
| 460 |
+
bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE,
|
| 461 |
+
)
|
| 462 |
+
model_kwargs["quantization_config"] = quant_config
|
| 463 |
+
_log(f"Using 4-bit quantization: {BNB_4BIT_QUANT_TYPE}, compute_dtype={BNB_4BIT_COMPUTE_DTYPE}, double_quant={BNB_4BIT_USE_DOUBLE_QUANT}")
|
| 464 |
+
except Exception as e:
|
| 465 |
+
_log(f"BitsAndBytes quantization failed: {e}; falling back to full precision")
|
| 466 |
+
|
| 467 |
# Device and dtype resolution
|
| 468 |
try:
|
| 469 |
import torch # local import to avoid heavy import at module load
|
requirements.txt
CHANGED
|
@@ -6,6 +6,7 @@ python-multipart>=0.0.6
|
|
| 6 |
# HF ecosystem
|
| 7 |
transformers>=4.44.0
|
| 8 |
accelerate>=0.33.0
|
|
|
|
| 9 |
|
| 10 |
# Multimedia + utils
|
| 11 |
pillow>=10.0.0
|
|
|
|
| 6 |
# HF ecosystem
|
| 7 |
transformers>=4.44.0
|
| 8 |
accelerate>=0.33.0
|
| 9 |
+
bitsandbytes>=0.43.0
|
| 10 |
|
| 11 |
# Multimedia + utils
|
| 12 |
pillow>=10.0.0
|