KillerKing93 committed on
Commit
7e03eea
·
verified ·
1 Parent(s): 7456af6

Sync from GitHub 233d3a7

Browse files
Files changed (3) hide show
  1. .env.example +7 -0
  2. main.py +23 -1
  3. requirements.txt +1 -0
.env.example CHANGED
@@ -16,6 +16,13 @@ MAX_VIDEO_FRAMES=16
16
  # Transformers loading hints
17
  DEVICE_MAP=auto
18
  TORCH_DTYPE=auto
 
 
 
 
 
 
 
19
  # Persistent SSE session store (SQLite)
20
  # Enable to persist streaming chunks per session_id and allow resume after server restarts.
21
  # 1=true, 0=false
 
16
  # Transformers loading hints
17
  DEVICE_MAP=auto
18
  TORCH_DTYPE=auto
19
+
20
+ # Quantization config (BitsAndBytes 4-bit)
21
+ # Enable 4-bit quantization to reduce VRAM usage (~5GB -> ~1.5GB)
22
+ LOAD_IN_4BIT=1
23
+ BNB_4BIT_COMPUTE_DTYPE=float16
24
+ BNB_4BIT_USE_DOUBLE_QUANT=1
25
+ BNB_4BIT_QUANT_TYPE=nf4
26
  # Persistent SSE session store (SQLite)
27
  # Enable to persist streaming chunks per session_id and allow resume after server restarts.
28
  # 1=true, 0=false
main.py CHANGED
@@ -79,6 +79,12 @@ MAX_VIDEO_FRAMES = int(os.getenv("MAX_VIDEO_FRAMES", "16"))
79
  DEVICE_MAP = os.getenv("DEVICE_MAP", "auto")
80
  TORCH_DTYPE = os.getenv("TORCH_DTYPE", "auto")
81
 
 
 
 
 
 
 
82
  # Persistent session store (SQLite)
83
  PERSIST_SESSIONS = str(os.getenv("PERSIST_SESSIONS", "0")).lower() in ("1", "true", "yes", "y")
84
  SESSIONS_DB_PATH = os.getenv("SESSIONS_DB_PATH", "sessions.db")
@@ -426,7 +432,7 @@ class CancelResponse(BaseModel):
426
  class Engine:
427
  def __init__(self, model_id: str, hf_token: Optional[str] = None):
428
  # Lazy import heavy deps
429
- from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForVision2Seq, AutoModel
430
  # AutoModelForImageTextToText is the v5+ replacement for Vision2Seq in Transformers
431
  try:
432
  from transformers import AutoModelForImageTextToText # type: ignore
@@ -442,6 +448,22 @@ class Engine:
442
  # Only pass 'token' (use_auth_token is deprecated and causes conflicts)
443
  model_kwargs["token"] = hf_token
444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  # Device and dtype resolution
446
  try:
447
  import torch # local import to avoid heavy import at module load
 
79
  DEVICE_MAP = os.getenv("DEVICE_MAP", "auto")
80
  TORCH_DTYPE = os.getenv("TORCH_DTYPE", "auto")
81
 
82
+ # Quantization config (BitsAndBytes)
83
+ LOAD_IN_4BIT = str(os.getenv("LOAD_IN_4BIT", "1")).lower() in ("1", "true", "yes", "y")
84
+ BNB_4BIT_COMPUTE_DTYPE = os.getenv("BNB_4BIT_COMPUTE_DTYPE", "float16")
85
+ BNB_4BIT_USE_DOUBLE_QUANT = str(os.getenv("BNB_4BIT_USE_DOUBLE_QUANT", "1")).lower() in ("1", "true", "yes", "y")
86
+ BNB_4BIT_QUANT_TYPE = os.getenv("BNB_4BIT_QUANT_TYPE", "nf4")
87
+
88
  # Persistent session store (SQLite)
89
  PERSIST_SESSIONS = str(os.getenv("PERSIST_SESSIONS", "0")).lower() in ("1", "true", "yes", "y")
90
  SESSIONS_DB_PATH = os.getenv("SESSIONS_DB_PATH", "sessions.db")
 
432
  class Engine:
433
  def __init__(self, model_id: str, hf_token: Optional[str] = None):
434
  # Lazy import heavy deps
435
+ from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForVision2Seq, AutoModel, BitsAndBytesConfig
436
  # AutoModelForImageTextToText is the v5+ replacement for Vision2Seq in Transformers
437
  try:
438
  from transformers import AutoModelForImageTextToText # type: ignore
 
448
  # Only pass 'token' (use_auth_token is deprecated and causes conflicts)
449
  model_kwargs["token"] = hf_token
450
 
451
+ # Add quantization config if enabled
452
+ if LOAD_IN_4BIT:
453
+ try:
454
+ import torch
455
+ compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE, torch.float16)
456
+ quant_config = BitsAndBytesConfig(
457
+ load_in_4bit=True,
458
+ bnb_4bit_compute_dtype=compute_dtype,
459
+ bnb_4bit_use_double_quant=BNB_4BIT_USE_DOUBLE_QUANT,
460
+ bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE,
461
+ )
462
+ model_kwargs["quantization_config"] = quant_config
463
+ _log(f"Using 4-bit quantization: {BNB_4BIT_QUANT_TYPE}, compute_dtype={BNB_4BIT_COMPUTE_DTYPE}, double_quant={BNB_4BIT_USE_DOUBLE_QUANT}")
464
+ except Exception as e:
465
+ _log(f"BitsAndBytes quantization failed: {e}; falling back to full precision")
466
+
467
  # Device and dtype resolution
468
  try:
469
  import torch # local import to avoid heavy import at module load
requirements.txt CHANGED
@@ -6,6 +6,7 @@ python-multipart>=0.0.6
6
  # HF ecosystem
7
  transformers>=4.44.0
8
  accelerate>=0.33.0
 
9
 
10
  # Multimedia + utils
11
  pillow>=10.0.0
 
6
  # HF ecosystem
7
  transformers>=4.44.0
8
  accelerate>=0.33.0
9
+ bitsandbytes>=0.43.0
10
 
11
  # Multimedia + utils
12
  pillow>=10.0.0