Spaces:
Paused
Paused
| import logging | |
| from typing import Any, Optional | |
| import torch | |
| from omegaconf import DictConfig, OmegaConf | |
| from safetensors.torch import load_model | |
def load_config(cfg_path: str) -> Any:
    """
    Load and resolve a configuration file.

    The file is read with ``OmegaConf.load``, checked to be a mapping-style
    config, and then all interpolations in it are resolved in place.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        Any: The loaded and resolved configuration object (a ``DictConfig``).

    Raises:
        AssertionError: If the loaded configuration is not an instance of
            DictConfig (for example, when the file's top level is a list).
    """
    cfg = OmegaConf.load(cfg_path)
    # Explicit check rather than a bare ``assert``: asserts are stripped when
    # Python runs with -O, which would let a non-mapping config slip through.
    # AssertionError is kept because it is the documented exception type.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Config '{cfg_path}' must be a mapping (DictConfig), "
            f"got {type(cfg).__name__}"
        )
    OmegaConf.resolve(cfg)
    return cfg
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Build a structured OmegaConf config from a plain configuration mapping.

    Every entry of ``cfg`` is passed as a keyword argument to ``cfg_type``.
    The resulting instance is then wrapped with ``OmegaConf.structured``.

    Args:
        cfg_type (Any): Callable used to build the structured object from the
            keys of ``cfg`` (presumably a dataclass type; confirm with callers).
        cfg (DictConfig): The configuration whose entries become the keyword
            arguments of ``cfg_type``.

    Returns:
        Any: The structured configuration produced by ``OmegaConf.structured``.
    """
    instance = cfg_type(**cfg)
    return OmegaConf.structured(instance)
def load_model_weights(
    model: torch.nn.Module, ckpt_path: str, *, strict: bool = True
) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.

    The model is updated in place.

    Args:
        model: PyTorch model to load weights into.
        ckpt_path: Path to the safetensors checkpoint file. ``os.PathLike``
            objects (e.g. ``pathlib.Path``) are accepted as well.
        strict: Forwarded to ``safetensors.torch.load_model``. When True (the
            default, matching the previous behavior), missing or unexpected
            keys make ``load_model`` raise. When False, they are tolerated and
            reported through a logging warning.

    Returns:
        None

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
    """
    import os

    # Normalise PathLike inputs so the suffix check below works for them too.
    path = os.fspath(ckpt_path)
    # Explicit check rather than a bare ``assert``: asserts are stripped under
    # ``python -O``. AssertionError is kept for backward compatibility.
    if not path.endswith(".safetensors"):
        raise AssertionError(
            f"Checkpoint path '{path}' is not a safetensors file"
        )

    missing, unexpected = load_model(model, path, strict=strict)
    # Only reachable with strict=False; in strict mode load_model raises instead.
    if missing or unexpected:
        logging.getLogger(__name__).warning(
            "Non-strict load of '%s': %d missing key(s) %s, "
            "%d unexpected key(s) %s",
            path,
            len(missing),
            sorted(missing),
            len(unexpected),
            sorted(unexpected),
        )