from __future__ import annotations

import multiprocessing as mp
import threading
from typing import Dict, Optional, Sequence, Union

import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src import declare_active_modules, BloomConfig
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
from src.server.backend import TransformerBackend
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class Server(threading.Thread):
    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""

    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        *,
        device: torch.device,
        num_connection_handlers: int = 8,
        update_period: float = 30,
        expiration: Optional[float] = None,
        start: bool,
        **kwargs,
    ):
        threading.Thread.__init__(self)
        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
        self.conn_handlers = [
            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
        ]
        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
        self.dht_handler_thread = ModuleAnnouncerThread(
            self.module_backends, dht, update_period, expiration, daemon=True
        )
        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state

        if start:
            self.run_in_background(await_ready=True)

    def run(self):
        """
        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Serving {len(self.module_backends)} blocks:")
        for expert_name, backend in self.module_backends.items():
            # count all parameters: blocks are frozen in Server.create, so filtering by
            # requires_grad would always report zero
            num_parameters = sum(p.numel() for p in backend.module.parameters())
            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")

        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        if self.module_backends:
            self.dht_handler_thread.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()
            process.ready.result()

        try:
            self.runtime.run()
        finally:
            self.shutdown()

    # noinspection PyMethodOverriding
    @classmethod
    def create(
        cls,
        prefix: Optional[str],
        converted_model_name_or_path: str,
        num_blocks: Optional[int] = None,
        block_indices: Optional[str] = None,
        num_handlers: Optional[int] = None,
        min_batch_size: int = 1,
        max_batch_size: int = 4096,
        torch_dtype: str = "auto",
        cache_size_bytes: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        initial_peers: Sequence[str] = (),
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 30,
        expiration: Optional[float] = None,
        use_auth_token: Optional[str] = None,
        *,
        start: bool,
        **kwargs,
    ) -> Server:
| """Create a server with one or more bloom blocks. See run_server.py for documentation.""" | |
        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)

        if prefix is None:
            prefix = converted_model_name_or_path
            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                f"please specify --prefix manually when starting a server"
            )
            logger.info(f"Automatic dht prefix: {prefix}")

        assert (block_indices is None) != (num_blocks is None), "Please specify exactly one of num_blocks or block_indices"

        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        memory_cache = MemoryCache(device, cache_size_bytes)

        if isinstance(torch_dtype, str):
            torch_dtype = DTYPE_MAP[torch_dtype]
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"

        if block_indices is not None:
            try:
                first_block_index, last_block_index = block_indices.split(":")
                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
            except Exception as e:
                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
                raise
            block_indices = range(first_block_index, last_block_index)
        else:
            assert num_blocks is not None
            block_indices = range(num_blocks)  # TODO replace with proper load balancing
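        # Either way, block_indices now follows Python range semantics: "3:6" selects
        # blocks 3, 4 and 5 (the end index is exclusive), and num_blocks=3 selects 0, 1, 2.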

        block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)

        # initialize modules
        blocks = {}
        for block_index in block_indices:
            module_uid = f"{prefix}.{block_index}"
            block = load_pretrained_block(
                converted_model_name_or_path,
                block_index,
                block_config,
                torch_dtype=torch_dtype,
                use_auth_token=use_auth_token,
            )
            for param in block.parameters():
                param.requires_grad = False  # the server runs frozen blocks; no local training
            blocks[module_uid] = TransformerBackend(
                module_uid,
                block,
                memory_cache=memory_cache,
                # input/output schema: a (batch, sequence, hidden_size) activation tensor
                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                kwargs_schema={},
                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )

        num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
        return cls(
            dht,
            blocks,
            num_connection_handlers=num_handlers,
            device=device,
            stats_report_interval=stats_report_interval,
            update_period=update_period,
            expiration=expiration,
            start=start,
        )

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts Server in a background thread. If await_ready, this method will wait until the background
        server is ready to process incoming requests, or for at most :timeout: seconds.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the server is ready to process requests.

        Example
        =======
        >>> server.start()
        >>> server.ready.wait(timeout=10)
        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the server, process-safe.
        Please note that terminating the server in any other way (e.g. by killing its processes) may leave zombie processes.
        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        self.ready.clear()

        for process in self.conn_handlers:
            process.terminate()
            process.join()
        logger.debug("Connection handlers terminated")

        if self.module_backends:
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        self.dht.shutdown()
        self.dht.join()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()
        logger.info("Server shut down successfully")


class ModuleAnnouncerThread(threading.Thread):
    """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""

    def __init__(
        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
    ):
        super().__init__(**kwargs)
        if expiration is None:
            # by default, keep DHT entries alive for at least two update periods,
            # but never less than the maximum allowed DHT clock discrepancy
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
        self.module_backends = module_backends
        self.dht = dht
        self.update_period = update_period
        self.expiration = expiration
        self.stop = threading.Event()

    def run(self) -> None:
        declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
        while not self.stop.wait(self.update_period):
            declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)