Spaces:
Runtime error
Runtime error
| from typing import Callable, List, Dict | |
| from multiprocessing.pool import ThreadPool | |
| from tqdm import tqdm | |
| from threading import Thread | |
| import asyncio | |
| from functools import wraps | |
def async_call_func(func):
    """Wrap a blocking callable so it can be awaited from asyncio code.

    Each call of the returned coroutine function submits ``func`` to the
    running event loop's default executor, so the loop itself is not blocked
    while ``func`` executes.

    Args:
        func: A synchronous callable.

    Returns:
        An ``async`` function accepting the same arguments as ``func``. Awaiting
        it yields ``func``'s return value; exceptions raised by ``func`` are
        re-raised at the ``await``.
    """
    from functools import partial  # local import: the file header only imports ``wraps``

    @wraps(func)
    async def wrapper(*args, **kwargs):
        # get_running_loop() is the supported way to obtain the loop from inside a coroutine.
        loop = asyncio.get_running_loop()
        # run_in_executor() forwards positional arguments only; binding everything
        # with partial() makes keyword arguments (wrapper(x, key=1)) work instead
        # of raising TypeError.
        return await loop.run_in_executor(None, partial(func, *args, **kwargs))

    return wrapper
def async_call(fn):
    """Decorate ``fn`` so that every call runs it on a newly started thread.

    The decorated call returns immediately and ``fn``'s return value is
    discarded. The started ``threading.Thread`` is returned so callers that
    need to wait for completion can ``join()`` it; callers that ignored the
    previous ``None`` return are unaffected.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        worker = Thread(target=fn, args=args, kwargs=kwargs)
        worker.start()
        return worker

    return wrapper
def parallel_execution(*args, action: Callable, num_processes=32, print_progress=False, sequential=False, async_return=False, desc=None, **kwargs):
    """Call ``action`` once per element of the list-valued arguments.

    Copied from EasyVolCap (author: Zhen Xu, https://github.com/dendenxu).

    The first ``list`` found among ``args`` (and then among ``kwargs`` values)
    fixes the number of calls, ``length``. For call ``i``, every positional or
    keyword argument that is a ``list`` with exactly ``length`` items is
    replaced by its ``i``-th item. Every other argument, including lists of a
    different length, is passed through unchanged.

    Args:
        *args: Positional arguments for ``action``.
        action: Callable invoked once per index.
        num_processes: Number of worker threads in the ``ThreadPool``.
        print_progress: Show a ``tqdm`` progress bar while results are collected.
        sequential: Run the calls one after another in the calling thread
            instead of in a thread pool.
        async_return: Pool mode only. Return the ``ThreadPool`` right after
            submitting the work instead of waiting for results. The caller then
            owns the pool and must ``close()``/``join()`` it. The per-call
            results are not returned in this mode.
        desc: Progress-bar description passed to ``tqdm``.
        **kwargs: Keyword arguments for ``action``.

    Returns:
        The list of ``action`` return values in index order, or the
        ``ThreadPool`` when ``async_return`` is true and ``sequential`` is false.

    Raises:
        NotImplementedError: If no argument is a ``list``.
        Any exception raised by ``action`` is re-raised to the caller. In pool
        mode the pool is terminated first, so no worker threads are leaked.
    """
    def get_length(args: List, kwargs: Dict):
        # The first list argument (positional first, then keyword) decides the call count.
        for a in args:
            if isinstance(a, list):
                return len(a)
        for v in kwargs.values():
            if isinstance(v, list):
                return len(v)
        raise NotImplementedError('parallel_execution expects at least one list argument to distribute')

    def get_action_args(length: int, args: List, kwargs: Dict, i: int):
        # Only lists whose length matches the call count are indexed; everything else is shared.
        # TODO: Support all types of iterable
        action_args = [(arg[i] if isinstance(arg, list) and len(arg) == length else arg) for arg in args]
        action_kwargs = {key: (value[i] if isinstance(value, list) and len(value) == length else value)
                         for key, value in kwargs.items()}
        return action_args, action_kwargs

    # Validate the inputs before any pool exists, so a bad call cannot leak worker threads.
    length = get_length(args, kwargs)

    if sequential:
        results = []
        for i in tqdm(range(length), desc=desc, disable=not print_progress):
            action_args, action_kwargs = get_action_args(length, args, kwargs, i)
            results.append(action(*action_args, **action_kwargs))
        return results

    pool = ThreadPool(processes=num_processes)
    try:
        # Submit every call up front; the pool runs them concurrently.
        asyncs = []
        for i in range(length):
            action_args, action_kwargs = get_action_args(length, args, kwargs, i)
            asyncs.append(pool.apply_async(action, action_args, action_kwargs))

        if async_return:
            # Ownership of the still-running pool passes to the caller.
            return pool

        results = []
        for async_result in tqdm(asyncs, desc=desc, disable=not print_progress):
            # .get() blocks until this task finishes and re-raises its exception.
            results.append(async_result.get())
    except BaseException:
        # Drop queued tasks and shut the pool down instead of leaking it.
        pool.terminate()
        raise

    pool.close()
    pool.join()
    return results