import contextlib
import functools
import hashlib
import logging
import os

import requests
import torch
import tqdm

from TTS.tts.layers.bark.model import GPT, GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig
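
# On CUDA devices with bfloat16 support, `autocast` runs inference under
# torch.cuda.amp.autocast(dtype=torch.bfloat16); otherwise it falls back to a
# no-op context manager.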
if (
    torch.cuda.is_available()
    and hasattr(torch.cuda, "amp")
    and hasattr(torch.cuda.amp, "autocast")
    and torch.cuda.is_bf16_supported()
):
    autocast = functools.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
else:

    @contextlib.contextmanager
    def autocast():
        yield

# hold models in global scope to lazy load
logger = logging.getLogger(__name__)
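
# Warn at import time if this torch build lacks F.scaled_dot_product_attention,
# which the attention layers use for faster (flash-attention) inference.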
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    logger.warning(
        "torch version does not support flash attention. You will get significantly faster"
        + " inference speed by upgrading torch to the newest version / nightly."
    )
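
# MD5 of a local file, read in 4 KiB chunks; used to validate cached checkpoints
# against the expected remote checksum.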
def _md5(fname):
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
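
# Stream a remote checkpoint to disk with a tqdm progress bar, creating the
# cache directory if needed.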
def _download(from_s3_path, to_local_path, CACHE_DIR):
    os.makedirs(CACHE_DIR, exist_ok=True)
    response = requests.get(from_s3_path, stream=True)
    total_size_in_bytes = int(response.headers.get("content-length", 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
    with open(to_local_path, "wb") as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes not in [0, progress_bar.n]:
        raise ValueError(
            f"Download incomplete: expected {total_size_in_bytes} bytes, received {progress_bar.n}"
        )
class InferenceContext:
    def __init__(self, benchmark=False):
        # we can't expect inputs to be the same length, so disable benchmarking by default
        self._chosen_cudnn_benchmark = benchmark
        self._cudnn_benchmark = None

    def __enter__(self):
        self._cudnn_benchmark = torch.backends.cudnn.benchmark
        torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark

    def __exit__(self, exc_type, exc_value, exc_traceback):
        torch.backends.cudnn.benchmark = self._cudnn_benchmark


if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
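
# Combined inference context: cudnn benchmark handling, torch.inference_mode /
# no_grad, and bf16 autocast when available.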
@contextlib.contextmanager
def inference_mode():
    with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast():
        yield
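
# Free cached CUDA memory and synchronize; called after checkpoint tensors are dropped.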
def clear_cuda_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
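
# Load one of the Bark GPT models ("text", "coarse", or "fine"): validate or
# (re)download the checkpoint, build the matching config, load the state dict,
# and return the model in eval mode together with the updated config.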
def load_model(ckpt_path, device, config, model_type="text"):
    logger.info(f"loading {model_type} model from {ckpt_path}...")
    if device == "cpu":
        logger.warning("No GPU being used. Careful, inference might be extremely slow!")
| if model_type == "text": | |
| ConfigClass = GPTConfig | |
| ModelClass = GPT | |
| elif model_type == "coarse": | |
| ConfigClass = GPTConfig | |
| ModelClass = GPT | |
| elif model_type == "fine": | |
| ConfigClass = FineGPTConfig | |
| ModelClass = FineGPT | |
| else: | |
| raise NotImplementedError() | |
| if ( | |
| not config.USE_SMALLER_MODELS | |
| and os.path.exists(ckpt_path) | |
| and _md5(ckpt_path) != config.REMOTE_MODEL_PATHS[model_type]["checksum"] | |
| ): | |
| logger.warning(f"found outdated {model_type} model, removing...") | |
| os.remove(ckpt_path) | |
| if not os.path.exists(ckpt_path): | |
| logger.info(f"{model_type} model not found, downloading...") | |
| _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR) | |
| checkpoint = torch.load(ckpt_path, map_location=device) | |
    # this is a hack: older checkpoints store a single "vocab_size"; remap it to
    # the separate input/output vocab sizes expected by the config classes
    model_args = checkpoint["model_args"]
    if "input_vocab_size" not in model_args:
        model_args["input_vocab_size"] = model_args["vocab_size"]
        model_args["output_vocab_size"] = model_args["vocab_size"]
        del model_args["vocab_size"]
    gptconf = ConfigClass(**checkpoint["model_args"])
    if model_type == "text":
        config.semantic_config = gptconf
    elif model_type == "coarse":
        config.coarse_config = gptconf
    elif model_type == "fine":
        config.fine_config = gptconf
    model = ModelClass(gptconf)
    state_dict = checkpoint["model"]
    # fixup checkpoint: strip the "_orig_mod." prefix that torch.compile() adds
    # to state-dict keys
    unwanted_prefix = "_orig_mod."
    for k, _ in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
    # the .attn.bias entries are causal-mask buffers recreated at init,
    # so they are excluded from the key comparison
    extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
    extra_keys = set(k for k in extra_keys if not k.endswith(".attn.bias"))
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    missing_keys = set(k for k in missing_keys if not k.endswith(".attn.bias"))
    if len(extra_keys) != 0:
        raise ValueError(f"extra keys found: {extra_keys}")
    if len(missing_keys) != 0:
        raise ValueError(f"missing keys: {missing_keys}")
    model.load_state_dict(state_dict, strict=False)
    n_params = model.get_num_params()
    val_loss = checkpoint["best_val_loss"].item()
    logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
    model.eval()
    model.to(device)
    del checkpoint, state_dict
    clear_cuda_cache()
    return model, config
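
# Example usage (a minimal sketch; the config object and checkpoint filename are
# assumptions that depend on the caller, e.g. a BarkConfig-style config exposing
# CACHE_DIR, REMOTE_MODEL_PATHS, and USE_SMALLER_MODELS):
#
#   config = BarkConfig()  # assumed config class
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   ckpt_path = os.path.join(config.CACHE_DIR, "text_2.pt")  # hypothetical filename
#   model, config = load_model(ckpt_path, device, config, model_type="text")
#   with inference_mode():
#       ...  # run generation with `model`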