devjas1 committed · Commit ec8ba60
Parent(s): 20bca60
(CHORE+FIX)[Remove Async]: Remove async inference utilities for polymer classification
Files changed:
- utils/async_inference.py (+0 -254)
utils/async_inference.py
DELETED
@@ -1,254 +0,0 @@
```python
"""
Asynchronous inference utilities for polymer classification.
Supports async processing for improved UI responsiveness.
"""

import concurrent.futures
import time
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
from enum import Enum

import streamlit as st
import numpy as np


class InferenceStatus(Enum):
    """Enumeration of possible statuses for an inference task."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class InferenceTask:
    """Represents an asynchronous inference task."""

    task_id: str
    model_name: str
    input_data: np.ndarray
    status: InferenceStatus = InferenceStatus.PENDING
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    start_time: Optional[float] = None
    end_time: Optional[float] = None

    @property
    def duration(self) -> Optional[float]:
        if self.start_time and self.end_time:
            return self.end_time - self.start_time
        return None


class AsyncInferenceManager:
    """Manages asynchronous inference tasks for multiple models."""

    def __init__(self, max_workers: int = 3):
        self.max_workers = max_workers
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        self.tasks: Dict[str, InferenceTask] = {}
        self._task_counter = 0

    def generate_task_id(self) -> str:
        """Generate unique task ID."""
        self._task_counter += 1
        return f"task_{self._task_counter}_{int(time.time() * 1000)}"

    def submit_inference(
        self,
        model_name: str,
        input_data: np.ndarray,
        inference_func: Callable,
        **kwargs,
    ) -> str:
        """Submit an inference task for asynchronous execution."""
        task_id = self.generate_task_id()
        task = InferenceTask(
            task_id=task_id, model_name=model_name, input_data=input_data
        )

        self.tasks[task_id] = task

        # Submit to thread pool
        self.executor.submit(self._run_inference, task, inference_func, **kwargs)

        return task_id

    def _run_inference(
        self, task: InferenceTask, inference_func: Callable, **kwargs
    ) -> None:
        """Execute inference task."""
        try:
            task.status = InferenceStatus.RUNNING
            task.start_time = time.time()

            # Run inference
            result = inference_func(task.input_data, task.model_name, **kwargs)

            task.result = result
            task.status = InferenceStatus.COMPLETED
            task.end_time = time.time()

        except (
            ValueError,
            TypeError,
            RuntimeError,
        ) as e:  # Replace with specific exceptions
            task.error = str(e)
            task.status = InferenceStatus.FAILED
            task.end_time = time.time()

    def get_task_status(self, task_id: str) -> Optional[InferenceTask]:
        """Get status of a specific task."""
        return self.tasks.get(task_id)

    def get_completed_tasks(self) -> List[InferenceTask]:
        """Get all completed tasks."""
        return [
            task
            for task in self.tasks.values()
            if task.status == InferenceStatus.COMPLETED
        ]

    def get_failed_tasks(self) -> List[InferenceTask]:
        """Get all failed tasks."""
        return [
            task
            for task in self.tasks.values()
            if task.status == InferenceStatus.FAILED
        ]

    def wait_for_completion(self, task_ids: List[str], timeout: float = 30.0) -> bool:
        """Wait for specific tasks to complete."""
        start_time = time.time()
        while time.time() - start_time < timeout:
            all_done = all(
                self.tasks[tid].status
                in [InferenceStatus.COMPLETED, InferenceStatus.FAILED]
                for tid in task_ids
                if tid in self.tasks
            )
            if all_done:
                return True
            time.sleep(0.1)
        return False

    def cleanup_completed_tasks(self, max_age: float = 300.0) -> None:
        """Clean up old completed tasks."""
        current_time = time.time()
        to_remove = []

        for task_id, task in self.tasks.items():
            if (
                task.end_time
                and current_time - task.end_time > max_age
                and task.status in [InferenceStatus.COMPLETED, InferenceStatus.FAILED]
            ):
                to_remove.append(task_id)

        for task_id in to_remove:
            del self.tasks[task_id]

    def shutdown(self):
        """Shutdown the executor."""
        self.executor.shutdown(wait=True)


class AsyncInferenceManagerSingleton:
    """Singleton wrapper for AsyncInferenceManager."""

    _instance: Optional[AsyncInferenceManager] = None

    @classmethod
    def get_instance(cls) -> AsyncInferenceManager:
        """Get the singleton instance of AsyncInferenceManager."""
        if cls._instance is None:
            cls._instance = AsyncInferenceManager()
        return cls._instance


def get_async_inference_manager() -> AsyncInferenceManager:
    """Get or create the singleton async inference manager."""
    return AsyncInferenceManagerSingleton.get_instance()


@st.cache_resource
def get_cached_async_manager():
    """Get cached async inference manager for Streamlit."""
    return AsyncInferenceManager()


def submit_batch_inference(
    model_names: List[str], input_data: np.ndarray, inference_func: Callable, **kwargs
) -> List[str]:
    """Submit batch inference for multiple models."""
    manager = get_async_inference_manager()
    task_ids = []

    for model_name in model_names:
        task_id = manager.submit_inference(
            model_name=model_name,
            input_data=input_data,
            inference_func=inference_func,
            **kwargs,
        )
        task_ids.append(task_id)

    return task_ids


def check_inference_progress(task_ids: List[str]) -> Dict[str, Dict[str, Any]]:
    """Check progress of multiple inference tasks."""
    manager = get_async_inference_manager()
    progress = {}

    for task_id in task_ids:
        task = manager.get_task_status(task_id)
        if task:
            progress[task_id] = {
                "model_name": task.model_name,
                "status": task.status.value,
                "duration": task.duration,
                "error": task.error,
            }

    return progress


def wait_for_batch_completion(
    task_ids: List[str],
    timeout: float = 30.0,
    progress_callback: Optional[Callable] = None,
) -> Dict[str, Any]:
    """Wait for batch inference completion with progress updates."""
    manager = get_async_inference_manager()
    start_time = time.time()

    while time.time() - start_time < timeout:
        progress = check_inference_progress(task_ids)

        if progress_callback:
            progress_callback(progress)

        # Check if all tasks are done
        all_done = all(
            status["status"] in ["completed", "failed"] for status in progress.values()
        )

        if all_done:
            break

        time.sleep(0.2)

    # Collect results
    results = {}
    for task_id in task_ids:
        task = manager.get_task_status(task_id)
        if task:
            if task.status == InferenceStatus.COMPLETED:
                results[task.model_name] = task.result
            elif task.status == InferenceStatus.FAILED:
                results[task.model_name] = {"error": task.error}

    return results
```