devjas1 commited on
Commit
ec8ba60
·
1 Parent(s): 20bca60

(CHORE+FIX)[Remove Async]: Remove async inference utilities for polymer classification

Browse files
Files changed (1) hide show
  1. utils/async_inference.py +0 -254
utils/async_inference.py DELETED
@@ -1,254 +0,0 @@
1
- """
2
- Asynchronous inference utilities for polymer classification.
3
- Supports async processing for improved UI responsiveness.
4
- """
5
-
6
- import concurrent.futures
7
- import time
8
- from typing import Dict, Any, List, Optional, Callable
9
- from dataclasses import dataclass
10
- from enum import Enum
11
- import streamlit as st
12
- import numpy as np
13
-
14
-
15
- class InferenceStatus(Enum):
16
- """Enumeration of possible statuses for an inference task."""
17
-
18
- PENDING = "pending"
19
- RUNNING = "running"
20
- COMPLETED = "completed"
21
- FAILED = "failed"
22
-
23
-
24
- @dataclass
25
- class InferenceTask:
26
- """Represents an asynchronous inference task."""
27
-
28
- task_id: str
29
- model_name: str
30
- input_data: np.ndarray
31
- status: InferenceStatus = InferenceStatus.PENDING
32
- result: Optional[Dict[str, Any]] = None
33
- error: Optional[str] = None
34
- start_time: Optional[float] = None
35
- end_time: Optional[float] = None
36
-
37
- @property
38
- def duration(self) -> Optional[float]:
39
- if self.start_time and self.end_time:
40
- return self.end_time - self.start_time
41
- return None
42
-
43
-
44
- class AsyncInferenceManager:
45
- """Manages asynchronous inference tasks for multiple models."""
46
-
47
- def __init__(self, max_workers: int = 3):
48
- self.max_workers = max_workers
49
- self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
50
- self.tasks: Dict[str, InferenceTask] = {}
51
- self._task_counter = 0
52
-
53
- def generate_task_id(self) -> str:
54
- """Generate unique task ID."""
55
- self._task_counter += 1
56
- return f"task_{self._task_counter}_{int(time.time() * 1000)}"
57
-
58
- def submit_inference(
59
- self,
60
- model_name: str,
61
- input_data: np.ndarray,
62
- inference_func: Callable,
63
- **kwargs,
64
- ) -> str:
65
- """Submit an inference task for asynchronous execution."""
66
- task_id = self.generate_task_id()
67
- task = InferenceTask(
68
- task_id=task_id, model_name=model_name, input_data=input_data
69
- )
70
-
71
- self.tasks[task_id] = task
72
-
73
- # Submit to thread pool
74
- self.executor.submit(self._run_inference, task, inference_func, **kwargs)
75
-
76
- return task_id
77
-
78
- def _run_inference(
79
- self, task: InferenceTask, inference_func: Callable, **kwargs
80
- ) -> None:
81
- """Execute inference task."""
82
- try:
83
- task.status = InferenceStatus.RUNNING
84
- task.start_time = time.time()
85
-
86
- # Run inference
87
- result = inference_func(task.input_data, task.model_name, **kwargs)
88
-
89
- task.result = result
90
- task.status = InferenceStatus.COMPLETED
91
- task.end_time = time.time()
92
-
93
- except (
94
- ValueError,
95
- TypeError,
96
- RuntimeError,
97
- ) as e: # Replace with specific exceptions
98
- task.error = str(e)
99
- task.status = InferenceStatus.FAILED
100
- task.end_time = time.time()
101
-
102
- def get_task_status(self, task_id: str) -> Optional[InferenceTask]:
103
- """Get status of a specific task."""
104
- return self.tasks.get(task_id)
105
-
106
- def get_completed_tasks(self) -> List[InferenceTask]:
107
- """Get all completed tasks."""
108
- return [
109
- task
110
- for task in self.tasks.values()
111
- if task.status == InferenceStatus.COMPLETED
112
- ]
113
-
114
- def get_failed_tasks(self) -> List[InferenceTask]:
115
- """Get all failed tasks."""
116
- return [
117
- task
118
- for task in self.tasks.values()
119
- if task.status == InferenceStatus.FAILED
120
- ]
121
-
122
- def wait_for_completion(self, task_ids: List[str], timeout: float = 30.0) -> bool:
123
- """Wait for specific tasks to complete."""
124
- start_time = time.time()
125
- while time.time() - start_time < timeout:
126
- all_done = all(
127
- self.tasks[tid].status
128
- in [InferenceStatus.COMPLETED, InferenceStatus.FAILED]
129
- for tid in task_ids
130
- if tid in self.tasks
131
- )
132
- if all_done:
133
- return True
134
- time.sleep(0.1)
135
- return False
136
-
137
- def cleanup_completed_tasks(self, max_age: float = 300.0) -> None:
138
- """Clean up old completed tasks."""
139
- current_time = time.time()
140
- to_remove = []
141
-
142
- for task_id, task in self.tasks.items():
143
- if (
144
- task.end_time
145
- and current_time - task.end_time > max_age
146
- and task.status in [InferenceStatus.COMPLETED, InferenceStatus.FAILED]
147
- ):
148
- to_remove.append(task_id)
149
-
150
- for task_id in to_remove:
151
- del self.tasks[task_id]
152
-
153
- def shutdown(self):
154
- """Shutdown the executor."""
155
- self.executor.shutdown(wait=True)
156
-
157
-
158
- class AsyncInferenceManagerSingleton:
159
- """Singleton wrapper for AsyncInferenceManager."""
160
-
161
- _instance: Optional[AsyncInferenceManager] = None
162
-
163
- @classmethod
164
- def get_instance(cls) -> AsyncInferenceManager:
165
- """Get the singleton instance of AsyncInferenceManager."""
166
- if cls._instance is None:
167
- cls._instance = AsyncInferenceManager()
168
- return cls._instance
169
-
170
-
171
- def get_async_inference_manager() -> AsyncInferenceManager:
172
- """Get or create the singleton async inference manager."""
173
- return AsyncInferenceManagerSingleton.get_instance()
174
-
175
-
176
- @st.cache_resource
177
- def get_cached_async_manager():
178
- """Get cached async inference manager for Streamlit."""
179
- return AsyncInferenceManager()
180
-
181
-
182
- def submit_batch_inference(
183
- model_names: List[str], input_data: np.ndarray, inference_func: Callable, **kwargs
184
- ) -> List[str]:
185
- """Submit batch inference for multiple models."""
186
- manager = get_async_inference_manager()
187
- task_ids = []
188
-
189
- for model_name in model_names:
190
- task_id = manager.submit_inference(
191
- model_name=model_name,
192
- input_data=input_data,
193
- inference_func=inference_func,
194
- **kwargs,
195
- )
196
- task_ids.append(task_id)
197
-
198
- return task_ids
199
-
200
-
201
- def check_inference_progress(task_ids: List[str]) -> Dict[str, Dict[str, Any]]:
202
- """Check progress of multiple inference tasks."""
203
- manager = get_async_inference_manager()
204
- progress = {}
205
-
206
- for task_id in task_ids:
207
- task = manager.get_task_status(task_id)
208
- if task:
209
- progress[task_id] = {
210
- "model_name": task.model_name,
211
- "status": task.status.value,
212
- "duration": task.duration,
213
- "error": task.error,
214
- }
215
-
216
- return progress
217
-
218
-
219
- def wait_for_batch_completion(
220
- task_ids: List[str],
221
- timeout: float = 30.0,
222
- progress_callback: Optional[Callable] = None,
223
- ) -> Dict[str, Any]:
224
- """Wait for batch inference completion with progress updates."""
225
- manager = get_async_inference_manager()
226
- start_time = time.time()
227
-
228
- while time.time() - start_time < timeout:
229
- progress = check_inference_progress(task_ids)
230
-
231
- if progress_callback:
232
- progress_callback(progress)
233
-
234
- # Check if all tasks are done
235
- all_done = all(
236
- status["status"] in ["completed", "failed"] for status in progress.values()
237
- )
238
-
239
- if all_done:
240
- break
241
-
242
- time.sleep(0.2)
243
-
244
- # Collect results
245
- results = {}
246
- for task_id in task_ids:
247
- task = manager.get_task_status(task_id)
248
- if task:
249
- if task.status == InferenceStatus.COMPLETED:
250
- results[task.model_name] = task.result
251
- elif task.status == InferenceStatus.FAILED:
252
- results[task.model_name] = {"error": task.error}
253
-
254
- return results