devjas1 committed
Commit a2a2664 · 1 Parent(s): b2201ae

(FEAT)[Utilities: Async Inference Engine]: Add asynchronous inference engine for multi-model processing


- Enables concurrent inference across multiple models, improving UI responsiveness and scalability.
- Adds an `AsyncInferenceManager` class that manages async tasks with a `ThreadPoolExecutor`; supports async submission, status tracking, completion, and cleanup of inference tasks.
- Defines an `InferenceTask` dataclass encapsulating task metadata and results.
- Provides functions for batch submission, progress checking, and waiting for completion (usage sketch below):
  - `submit_batch_inference`: batch submission across multiple models.
  - `check_inference_progress`: real-time progress/status tracking.
  - `wait_for_batch_completion`: waits for all tasks, with an optional progress callback.
- Integrates Streamlit resource caching for the manager instance.

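Example usage (a minimal sketch; `run_inference` and the model names are illustrative stand-ins for the project's real inference entry point, which must accept `(input_data, model_name, **kwargs)` since that is how `_run_inference` invokes it):

```python
import numpy as np

from utils.async_inference import submit_batch_inference, wait_for_batch_completion


def run_inference(input_data: np.ndarray, model_name: str, **kwargs) -> dict:
    """Hypothetical inference callable; replace with the real model runner."""
    return {"model": model_name, "prediction": float(input_data.mean())}


spectrum = np.random.rand(500)  # placeholder input spectrum
task_ids = submit_batch_inference(
    model_names=["figure2", "resnet"],  # illustrative model names
    input_data=spectrum,
    inference_func=run_inference,
)
results = wait_for_batch_completion(task_ids, timeout=60.0)
# results maps model_name -> inference result, or {"error": ...} on failure
```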
Files changed (1)
  1. utils/async_inference.py +254 -0
utils/async_inference.py ADDED
@@ -0,0 +1,254 @@
+"""
+Asynchronous inference utilities for polymer classification.
+Supports async processing for improved UI responsiveness.
+"""
+
+import concurrent.futures
+import time
+from typing import Dict, Any, List, Optional, Callable
+from dataclasses import dataclass
+from enum import Enum
+import streamlit as st
+import numpy as np
+
+
+class InferenceStatus(Enum):
+    """Enumeration of possible statuses for an inference task."""
+
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+
+@dataclass
+class InferenceTask:
+    """Represents an asynchronous inference task."""
+
+    task_id: str
+    model_name: str
+    input_data: np.ndarray
+    status: InferenceStatus = InferenceStatus.PENDING
+    result: Optional[Dict[str, Any]] = None
+    error: Optional[str] = None
+    start_time: Optional[float] = None
+    end_time: Optional[float] = None
+
+    @property
+    def duration(self) -> Optional[float]:
+        """Wall-clock runtime in seconds, or None if the task has not finished."""
+        if self.start_time and self.end_time:
+            return self.end_time - self.start_time
+        return None
+
+
+class AsyncInferenceManager:
+    """Manages asynchronous inference tasks for multiple models."""
+
+    def __init__(self, max_workers: int = 3):
+        self.max_workers = max_workers
+        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
+        self.tasks: Dict[str, InferenceTask] = {}
+        self._task_counter = 0
+
+    def generate_task_id(self) -> str:
+        """Generate a unique task ID."""
+        self._task_counter += 1
+        return f"task_{self._task_counter}_{int(time.time() * 1000)}"
+
+    def submit_inference(
+        self,
+        model_name: str,
+        input_data: np.ndarray,
+        inference_func: Callable,
+        **kwargs,
+    ) -> str:
+        """Submit an inference task for asynchronous execution."""
+        task_id = self.generate_task_id()
+        task = InferenceTask(
+            task_id=task_id, model_name=model_name, input_data=input_data
+        )
+
+        self.tasks[task_id] = task
+
+        # Submit to thread pool
+        self.executor.submit(self._run_inference, task, inference_func, **kwargs)
+
+        return task_id
+
+    def _run_inference(
+        self, task: InferenceTask, inference_func: Callable, **kwargs
+    ) -> None:
+        """Execute an inference task on a worker thread."""
+        try:
+            task.status = InferenceStatus.RUNNING
+            task.start_time = time.time()
+
+            # Run inference
+            result = inference_func(task.input_data, task.model_name, **kwargs)
+
+            task.result = result
+            task.status = InferenceStatus.COMPLETED
+            task.end_time = time.time()
+
+        except (ValueError, TypeError, RuntimeError) as e:
+            # Only these exception types mark the task FAILED; anything else
+            # propagates into the executor and would leave the task RUNNING.
+            # Widen this tuple if inference_func can raise other errors.
+            task.error = str(e)
+            task.status = InferenceStatus.FAILED
+            task.end_time = time.time()
+
+    def get_task_status(self, task_id: str) -> Optional[InferenceTask]:
+        """Get status of a specific task."""
+        return self.tasks.get(task_id)
+
+    def get_completed_tasks(self) -> List[InferenceTask]:
+        """Get all completed tasks."""
+        return [
+            task
+            for task in self.tasks.values()
+            if task.status == InferenceStatus.COMPLETED
+        ]
+
+    def get_failed_tasks(self) -> List[InferenceTask]:
+        """Get all failed tasks."""
+        return [
+            task
+            for task in self.tasks.values()
+            if task.status == InferenceStatus.FAILED
+        ]
+
+    def wait_for_completion(self, task_ids: List[str], timeout: float = 30.0) -> bool:
+        """Wait for specific tasks to complete. Unknown task IDs are ignored."""
+        start_time = time.time()
+        while time.time() - start_time < timeout:
+            all_done = all(
+                self.tasks[tid].status
+                in [InferenceStatus.COMPLETED, InferenceStatus.FAILED]
+                for tid in task_ids
+                if tid in self.tasks
+            )
+            if all_done:
+                return True
+            time.sleep(0.1)
+        return False
+
+    def cleanup_completed_tasks(self, max_age: float = 300.0) -> None:
+        """Clean up completed or failed tasks older than max_age seconds."""
+        current_time = time.time()
+        to_remove = []
+
+        for task_id, task in self.tasks.items():
+            if (
+                task.end_time
+                and current_time - task.end_time > max_age
+                and task.status in [InferenceStatus.COMPLETED, InferenceStatus.FAILED]
+            ):
+                to_remove.append(task_id)
+
+        for task_id in to_remove:
+            del self.tasks[task_id]
+
+    def shutdown(self):
+        """Shut down the executor, waiting for running tasks to finish."""
+        self.executor.shutdown(wait=True)
+
+
+class AsyncInferenceManagerSingleton:
+    """Singleton wrapper for AsyncInferenceManager."""
+
+    _instance: Optional[AsyncInferenceManager] = None
+
+    @classmethod
+    def get_instance(cls) -> AsyncInferenceManager:
+        """Get the singleton instance of AsyncInferenceManager."""
+        if cls._instance is None:
+            cls._instance = AsyncInferenceManager()
+        return cls._instance
+
+
+def get_async_inference_manager() -> AsyncInferenceManager:
+    """Get or create the singleton async inference manager."""
+    return AsyncInferenceManagerSingleton.get_instance()
+
+
+@st.cache_resource
+def get_cached_async_manager():
+    """Get cached async inference manager for Streamlit."""
+    return AsyncInferenceManager()
+
+
+def submit_batch_inference(
+    model_names: List[str], input_data: np.ndarray, inference_func: Callable, **kwargs
+) -> List[str]:
+    """Submit batch inference for multiple models."""
+    manager = get_async_inference_manager()
+    task_ids = []
+
+    for model_name in model_names:
+        task_id = manager.submit_inference(
+            model_name=model_name,
+            input_data=input_data,
+            inference_func=inference_func,
+            **kwargs,
+        )
+        task_ids.append(task_id)
+
+    return task_ids
+
+
+def check_inference_progress(task_ids: List[str]) -> Dict[str, Dict[str, Any]]:
+    """Check progress of multiple inference tasks."""
+    manager = get_async_inference_manager()
+    progress = {}
+
+    for task_id in task_ids:
+        task = manager.get_task_status(task_id)
+        if task:
+            progress[task_id] = {
+                "model_name": task.model_name,
+                "status": task.status.value,
+                "duration": task.duration,
+                "error": task.error,
+            }
+
+    return progress
+
+
+def wait_for_batch_completion(
+    task_ids: List[str],
+    timeout: float = 30.0,
+    progress_callback: Optional[Callable] = None,
+) -> Dict[str, Any]:
+    """Wait for batch inference completion with progress updates."""
+    manager = get_async_inference_manager()
+    start_time = time.time()
+
+    while time.time() - start_time < timeout:
+        progress = check_inference_progress(task_ids)
+
+        if progress_callback:
+            progress_callback(progress)
+
+        # Check if all tasks are done
+        all_done = all(
+            status["status"] in ["completed", "failed"] for status in progress.values()
+        )
+
+        if all_done:
+            break
+
+        time.sleep(0.2)
+
+    # Collect results, keyed by model name
+    results = {}
+    for task_id in task_ids:
+        task = manager.get_task_status(task_id)
+        if task:
+            if task.status == InferenceStatus.COMPLETED:
+                results[task.model_name] = task.result
+            elif task.status == InferenceStatus.FAILED:
+                results[task.model_name] = {"error": task.error}
+
+    return results
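
One way the optional progress callback might drive a Streamlit progress bar (a sketch, reusing the hypothetical `spectrum` and `run_inference` from the example above; `st.progress` is standard Streamlit API, and the callback runs on the script thread inside the polling loop, so Streamlit calls are safe):

```python
import streamlit as st

from utils.async_inference import submit_batch_inference, wait_for_batch_completion

bar = st.progress(0.0)


def render_progress(progress: dict) -> None:
    # progress maps task_id -> {"model_name", "status", "duration", "error"}
    done = sum(1 for p in progress.values() if p["status"] in ("completed", "failed"))
    bar.progress(done / max(len(progress), 1))


task_ids = submit_batch_inference(["figure2", "resnet"], spectrum, run_inference)
results = wait_for_batch_completion(task_ids, progress_callback=render_progress)
bar.progress(1.0)
```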