tommulder committed
Commit d405999 · 1 Parent(s): ab66d7a

docs(model_loader): add comprehensive docstrings and comments; no functional changes

Files changed (1)
  1. src/kybtech_dots_ocr/model_loader.py +159 -90
src/kybtech_dots_ocr/model_loader.py CHANGED
@@ -1,8 +1,24 @@
  """Dots.OCR Model Loader

- This module handles downloading and loading the Dots.OCR model using Hugging Face's
- snapshot_download functionality. It provides device selection, dtype configuration,
- and model initialization with proper error handling.
+ This module handles downloading and loading the Dots.OCR model using
+ Hugging Face's `snapshot_download`. It centralizes device selection,
+ dtype configuration, model initialization, and safe fallbacks.
+
+ Why this exists:
+ - Keep model lifecycle and I/O concerns isolated from API/business logic.
+ - Provide safe CPU defaults, optional CUDA acceleration, and optional
+   FlashAttention2 when compatible and explicitly enabled.
+
+ Key environment variables:
+ - DOTS_OCR_REPO_ID: HF repo to download (default: "rednote-hilab/dots.ocr").
+ - DOTS_OCR_LOCAL_DIR: Local cache directory for `snapshot_download`.
+ - DOTS_OCR_DEVICE: One of {"cpu", "cuda", "auto"}. "auto" prefers CUDA.
+ - DOTS_OCR_MAX_NEW_TOKENS: Max generated tokens per request.
+ - DOTS_OCR_FLASH_ATTENTION: "1" to attempt FlashAttention2 when compatible.
+ - DOTS_OCR_MIN_PIXELS / DOTS_OCR_MAX_PIXELS: Image size bounds pre-inference.
+ - DOTS_OCR_PROMPT: Optional default transcription prompt.
+
+ Usage: call `load_model()` once, then `extract_text(image)` per request.
  """

  import os
@@ -19,16 +35,21 @@ from PIL import Image
  logger = logging.getLogger(__name__)

  # Environment variable configuration
+ #
+ # These env vars make runtime behavior tunable without code changes. Defaults are
+ # conservative to favor stability on CPU-only platforms; performance features
+ # are opt-in and gated by compatibility checks.
  REPO_ID = os.getenv("DOTS_OCR_REPO_ID", "rednote-hilab/dots.ocr")
  LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr")
- DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto")
+ DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto")  # "auto" prefers CUDA if available
  MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048"))
- USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "0") == "1"
- MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136"))  # 56x56
- MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600"))  # 3360x3360
+ USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "0") == "1"  # opt-in
+ MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136"))  # 56x56 lower bound
+ MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600"))  # 3360x3360 upper bound
  CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT")

- # Default transcription prompt for faithful text extraction
+ # Default transcription prompt for faithful text extraction.
+ # Keep terse to reduce bias; we want faithful extraction, not translation or formatting.
  DEFAULT_PROMPT = (
      "Transcribe all visible text in the image in the original language. "
      "Do not translate. Preserve natural reading order. Output plain text only."
@@ -36,19 +57,34 @@ DEFAULT_PROMPT = (


  class DotsOCRModelLoader:
-     """Handles Dots.OCR model downloading, loading, and inference."""
-
+     """Handles Dots.OCR model downloading, loading, and inference.
+
+     Encapsulates model lifecycle (download, init, device placement), preprocessing,
+     and a narrow inference surface for OCR. Exposes a minimal API and maintains a
+     single global instance via helpers below.
+     """
+
      def __init__(self):
-         """Initialize the model loader."""
+         """Initialize the model loader.
+
+         Heavyweight work is deferred until `load_model()` so that constructing this
+         class is cheap. The default prompt is captured from env, if provided.
+         """
          self.model = None
          self.processor = None
          self.device = None
         self.dtype = None
         self.local_dir = None
         self.prompt = CUSTOM_PROMPT or DEFAULT_PROMPT
-
+
      def _determine_device_and_dtype(self) -> Tuple[str, torch.dtype]:
-         """Determine the best device and dtype based on availability and configuration."""
+         """Pick device and dtype based on availability and configuration.
+
+         Rules:
+         - Respect explicit "cpu" or "cuda" when valid.
+         - "auto" selects CUDA when available, else CPU.
+         - Use bfloat16 on CUDA for throughput; float32 on CPU for correctness.
+         """
          if DEVICE_CONFIG == "cpu":
              device = "cpu"
              dtype = torch.float32
@@ -67,35 +103,42 @@ class DotsOCRModelLoader:
              logger.warning(f"CUDA requested but not available, falling back to CPU")
              device = "cpu"
              dtype = torch.float32
-
+
          logger.info(f"Selected device: {device}, dtype: {dtype}")
          return device, dtype
-
+
      def _download_model(self) -> str:
-         """Download the model using snapshot_download."""
+         """Download the model using `snapshot_download` and ensure cache dir exists.
+
+         Returns the resolved local path for deterministic, offline-friendly loading.
+         Raises `RuntimeError` on failure.
+         """
          logger.info(f"Downloading model from {REPO_ID} to {LOCAL_DIR}")
-
+
          try:
              # Ensure local directory exists
              Path(LOCAL_DIR).mkdir(parents=True, exist_ok=True)
-
+
              # Download model snapshot
              local_path = snapshot_download(
                  repo_id=REPO_ID,
                  local_dir=LOCAL_DIR,
              )
-
+
              logger.info(f"Model downloaded successfully to {local_path}")
              return local_path
-
+
          except Exception as e:
              logger.error(f"Failed to download model: {e}")
              raise RuntimeError(f"Model download failed: {e}")

      def _can_use_flash_attn(self) -> bool:
          """Check whether FlashAttention2 can be enabled safely.
-
-         Returns True only if the package is importable and dtype is fp16/bf16.
+
+         Requires all of:
+         - DOTS_OCR_FLASH_ATTENTION toggle is set.
+         - `flash_attn` is importable.
+         - dtype is fp16/bf16 per library support.
          """
          if not USE_FLASH_ATTENTION:
              return False
@@ -103,33 +146,42 @@
              # Import check avoids runtime error from Transformers if not installed
              import flash_attn  # type: ignore  # noqa: F401
          except Exception:
-             logger.warning("flash_attn package not installed; disabling FlashAttention2")
+             logger.warning(
+                 "flash_attn package not installed; disabling FlashAttention2"
+             )
              return False
          # FlashAttention2 supports fp16/bf16 only (see HF docs)
          return self.dtype in (torch.float16, torch.bfloat16)
-
+
      def load_model(self) -> None:
-         """Load the Dots.OCR model and processor."""
+         """Load the Dots.OCR model and processor.
+
+         Steps:
+         1) Determine device/dtype
+         2) Download snapshot if missing
+         3) Load `AutoProcessor`
+         4) Configure attention/device mapping
+         5) Instantiate model and place on target device
+         """
          try:
              # Determine device and dtype
              self.device, self.dtype = self._determine_device_and_dtype()
-
+
              # Download model if not already present
              self.local_dir = self._download_model()
-
+
              # Load processor
              logger.info("Loading processor...")
              self.processor = AutoProcessor.from_pretrained(
-                 self.local_dir,
-                 trust_remote_code=True
+                 self.local_dir, trust_remote_code=True
              )
-
+
              # Load model with appropriate configuration
              model_kwargs = {
-                 "dtype": self.dtype,  # torch_dtype is deprecated
+                 "dtype": self.dtype,  # NOTE: `torch_dtype` is deprecated upstream
                  "trust_remote_code": True,
              }
-
+
              # Add device-specific configurations
              if self.device == "cuda":
                  # Prefer FlashAttention2 when truly available; otherwise SDPA
@@ -138,40 +190,46 @@
                      logger.info("Using flash attention 2")
                  else:
                      model_kwargs["attn_implementation"] = "sdpa"
-                     logger.info("Using SDPA attention (flash-attn unavailable or disabled)")
+                     logger.info(
+                         "Using SDPA attention (flash-attn unavailable or disabled)"
+                     )

                  # Use device_map for automatic GPU memory management
                  model_kwargs["device_map"] = "auto"
              else:
                  # For CPU, don't use device_map
                  model_kwargs["device_map"] = None
-
+
              logger.info("Loading model...")
              self.model = AutoModelForCausalLM.from_pretrained(
-                 self.local_dir,
-                 **model_kwargs
+                 self.local_dir, **model_kwargs
              )
-
+
              # Move model to device if not using device_map
              if self.device == "cpu" or model_kwargs.get("device_map") is None:
                  self.model = self.model.to(self.device)
-
+
              logger.info(f"Model loaded successfully on {self.device}")
-
+
          except Exception as e:
              logger.error(f"Failed to load model: {e}")
              raise RuntimeError(f"Model loading failed: {e}")
-
+
      def _preprocess_image(self, image: Image.Image) -> Image.Image:
-         """Preprocess image to meet model requirements."""
+         """Preprocess image to meet model requirements.
+
+         - Normalize to RGB
+         - Constrain pixel count within [MIN_PIXELS, MAX_PIXELS]
+         - Snap dimensions to multiples of 28 to satisfy backbone constraints
+         """
          # Convert to RGB if necessary
          if image.mode != "RGB":
              image = image.convert("RGB")
-
+
          # Calculate current pixel count
          width, height = image.size
          current_pixels = width * height
-
+
          # Resize if necessary to meet pixel requirements
          if current_pixels < MIN_PIXELS:
              # Scale up to meet minimum pixel requirement
@@ -179,104 +237,115 @@ class DotsOCRModelLoader:
              new_width = int(width * scale_factor)
              new_height = int(height * scale_factor)
              image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
-             logger.info(f"Scaled up image from {width}x{height} to {new_width}x{new_height}")
-
+             logger.info(
+                 f"Scaled up image from {width}x{height} to {new_width}x{new_height}"
+             )
+
          elif current_pixels > MAX_PIXELS:
              # Scale down to meet maximum pixel requirement
              scale_factor = (MAX_PIXELS / current_pixels) ** 0.5
              new_width = int(width * scale_factor)
              new_height = int(height * scale_factor)
              image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
-             logger.info(f"Scaled down image from {width}x{height} to {new_width}x{new_height}")
-
+             logger.info(
+                 f"Scaled down image from {width}x{height} to {new_width}x{new_height}"
+             )
+
          # Ensure dimensions are divisible by 28 (common requirement for vision models)
          width, height = image.size
          new_width = ((width + 27) // 28) * 28
          new_height = ((height + 27) // 28) * 28
-
+
          if new_width != width or new_height != height:
              image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
-             logger.info(f"Adjusted image dimensions to be divisible by 28: {new_width}x{new_height}")
-
+             logger.info(
+                 f"Adjusted image dimensions to be divisible by 28: {new_width}x{new_height}"
+             )
+
          return image
-
+
      @torch.inference_mode()
      def extract_text(self, image: Image.Image, prompt: Optional[str] = None) -> str:
-         """Extract text from an image using the loaded model."""
+         """Extract text from an image using the loaded model.
+
+         Builds a single-turn chat message with the image and a transcription prompt,
+         applies the model's chat template, and decodes deterministically (no sampling).
+         """
          if self.model is None or self.processor is None:
              raise RuntimeError("Model not loaded. Call load_model() first.")
-
+
          try:
              # Preprocess image
              processed_image = self._preprocess_image(image)
-
+
              # Use provided prompt or default
              text_prompt = prompt or self.prompt
-
+
              # Prepare messages for the model
-             messages = [{
-                 "role": "user",
-                 "content": [
-                     {"type": "image", "image": processed_image},
-                     {"type": "text", "text": text_prompt},
-                 ],
-             }]
-
-             # Apply chat template
+             messages = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "image", "image": processed_image},
+                         {"type": "text", "text": text_prompt},
+                     ],
+                 }
+             ]
+
+             # Apply chat template (preserves special tokens/formatting expected by model)
              text = self.processor.apply_chat_template(
-                 messages,
-                 tokenize=False,
-                 add_generation_prompt=True
+                 messages, tokenize=False, add_generation_prompt=True
              )
-
+
              # Process vision information (required for some models)
              try:
                  from qwen_vl_utils import process_vision_info
+
                  image_inputs, video_inputs = process_vision_info(messages)
              except ImportError:
                  # Fallback if qwen_vl_utils not available
                  logger.warning("qwen_vl_utils not available, using basic processing")
                  image_inputs = [processed_image]
                  video_inputs = []
-
+
              # Prepare inputs
              inputs = self.processor(
-                 text=[text],
-                 images=image_inputs,
-                 videos=video_inputs,
-                 padding=True,
-                 return_tensors="pt"
+                 text=[text],
+                 images=image_inputs,
+                 videos=video_inputs,
+                 padding=True,
+                 return_tensors="pt",
              ).to(self.device)
-
-             # Generate text
+
+             # Generate text deterministically (temperature=0, do_sample=False)
              output_ids = self.model.generate(
                  **inputs,
                  max_new_tokens=MAX_NEW_TOKENS,
                  do_sample=False,
                  temperature=0.0,
-                 pad_token_id=self.processor.tokenizer.eos_token_id
+                 pad_token_id=self.processor.tokenizer.eos_token_id,
              )
-
-             # Decode output
-             trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
+
+             # Decode only newly generated tokens (strip prompt tokens)
+             trimmed = [
+                 out[len(inp) :] for inp, out in zip(inputs.input_ids, output_ids)
+             ]
              decoded = self.processor.batch_decode(
-                 trimmed,
-                 skip_special_tokens=True,
-                 clean_up_tokenization_spaces=False
+                 trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
              )
-
+
              return decoded[0] if decoded else ""
-
+
          except Exception as e:
              logger.error(f"Text extraction failed: {e}")
              raise RuntimeError(f"Text extraction failed: {e}")
-
+
      def is_loaded(self) -> bool:
-         """Check if the model is loaded and ready for inference."""
+         """Return True when both model and processor are initialized."""
          return self.model is not None and self.processor is not None
-
+
      def get_model_info(self) -> Dict[str, Any]:
-         """Get information about the loaded model."""
+         """Get diagnostic information about the loaded model and configuration."""
          return {
              "device": self.device,
              "dtype": str(self.dtype),
@@ -285,7 +354,7 @@ class DotsOCRModelLoader:
              "max_new_tokens": MAX_NEW_TOKENS,
              "use_flash_attention": USE_FLASH_ATTENTION,
              "prompt": self.prompt,
-             "is_loaded": self.is_loaded()
+             "is_loaded": self.is_loaded(),
          }


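A minimal usage sketch of the lifecycle the new module docstring describes (configure via environment variables, call `load_model()` once, then `extract_text(image)` per request). The import path is inferred from the file location `src/kybtech_dots_ocr/model_loader.py`, and the image path is purely illustrative; this is not code from the commit.

# Minimal usage sketch based on the docstrings added in this commit.
# Assumption: the package is importable as `kybtech_dots_ocr` (inferred from
# src/kybtech_dots_ocr/model_loader.py); "page.png" is any local test image.
import os

from PIL import Image

# Optional tuning via environment variables; they must be set before importing
# the module, since it reads them at import time.
os.environ.setdefault("DOTS_OCR_DEVICE", "auto")         # "cpu", "cuda", or "auto"
os.environ.setdefault("DOTS_OCR_MAX_NEW_TOKENS", "2048")

from kybtech_dots_ocr.model_loader import DotsOCRModelLoader

loader = DotsOCRModelLoader()
loader.load_model()               # downloads the snapshot and initializes once

image = Image.open("page.png")
text = loader.extract_text(image)  # uses DEFAULT_PROMPT unless a prompt is passed
print(loader.get_model_info()["device"], len(text))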
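The `_preprocess_image` docstring summarizes three rules (normalize to RGB, keep the pixel count within bounds, snap dimensions to multiples of 28). As a standalone illustration of that arithmetic under the default bounds, not a call into the module, the 1000x700 size below is an arbitrary example and the scale-up formula mirrors the scale-down formula shown in the diff:

# Standalone sketch of the resizing rules described in _preprocess_image's docstring.
MIN_PIXELS = 3136        # 56x56 lower bound (default)
MAX_PIXELS = 11289600    # 3360x3360 upper bound (default)

width, height = 1000, 700
pixels = width * height  # 700000, already within [MIN_PIXELS, MAX_PIXELS]

if pixels < MIN_PIXELS:
    scale = (MIN_PIXELS / pixels) ** 0.5  # assumed symmetric to the scale-down case
    width, height = int(width * scale), int(height * scale)
elif pixels > MAX_PIXELS:
    scale = (MAX_PIXELS / pixels) ** 0.5
    width, height = int(width * scale), int(height * scale)

# Snap each dimension up to the next multiple of 28.
snapped = (((width + 27) // 28) * 28, ((height + 27) // 28) * 28)
print(snapped)  # (1008, 700): 1000 rounds up to 1008, 700 is already 25 * 28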