Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.distributed | |
| from accelerate import init_empty_weights | |
| from opentelemetry import trace | |
| from safetensors import safe_open | |
| from transformers import AutoTokenizer, AutoConfig | |
| from typing import Optional, List | |
| from text_generation_server.models import FlashCausalLM | |
| from text_generation_server.models.custom_modeling.flash_neox_modeling import ( | |
| FlashGPTNeoXForCausalLM, | |
| TensorParallelEmbedding, | |
| TensorParallelRowLinear, | |
| TensorParallelColumnLinear, | |
| ) | |
| from text_generation_server.utils import ( | |
| initialize_torch_distributed, | |
| weight_files, | |
| ) | |
# Module-level OpenTelemetry tracer named after this module. It is not
# referenced in the code visible here — presumably used elsewhere; verify.
tracer = trace.get_tracer(__name__)
class FlashNeoX(FlashCausalLM):
    """GPT-NeoX served through the generic ``FlashCausalLM`` loading path.

    This class only selects the architecture: all construction work is
    delegated to ``FlashCausalLM.__init__`` with ``FlashGPTNeoXForCausalLM``.
    """

    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        """Build the model for ``model_id`` at ``revision``.

        Args:
            model_id: Identifier forwarded unchanged to the parent constructor.
            revision: Optional revision forwarded unchanged.
            quantize: Quantization flag forwarded unchanged.
        """
        # Arguments are passed positionally, in the order the parent expects.
        super().__init__(FlashGPTNeoXForCausalLM, model_id, revision, quantize)
class FlashNeoXSharded(FlashNeoX):
    """GPT-NeoX whose weights are sharded over a ``torch.distributed`` group.

    Every rank builds the module tree without allocating weights
    (``init_empty_weights``), then ``load_weights`` fills in only the slice of
    each tensor-parallel parameter that belongs to that rank.
    """

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Set up the process group, tokenizer, config and this rank's weights.

        Args:
            model_id: Forwarded to ``AutoTokenizer``/``AutoConfig`` and
                ``weight_files`` to locate the ``.safetensors`` checkpoint.
            revision: Optional revision forwarded to the same loaders.
            quantize: Forwarded to ``load_weights``.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.past_pad = None
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0

        if torch.cuda.is_available():
            # The CUDA device index is the process rank.
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashNeoX is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Build the module tree without real weight storage; load_weights
        # assigns the actual tensors afterwards.
        with init_empty_weights():
            model = FlashGPTNeoXForCausalLM(config, self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(device)
        torch.distributed.barrier(group=self.process_group)

        # Deliberately bypasses FlashCausalLM.__init__ and calls the next class
        # in the MRO. NOTE(review): presumably because FlashCausalLM.__init__
        # would build its own unsharded model from model_id — confirm.
        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        """Load this rank's shard of every checkpoint tensor into ``model``.

        Declared ``@staticmethod`` because it takes no ``self``. Without the
        decorator, the ``self.load_weights(...)`` call in ``__init__`` binds the
        instance to ``model`` and raises ``TypeError``.

        How much of each tensor is loaded depends on the submodule that owns it:

        * ``TensorParallelColumnLinear``, ``TensorParallelEmbedding`` and
          ``embed_out.weight`` (when ``model.gpt_neox.tp_embeddings``): this
          rank's contiguous block of dimension 0.
        * ``TensorParallelRowLinear`` weight: this rank's block of dimension 1.
          The bias is read whole but zeroed on every rank except 0, so it is
          added only once.
        * Anything else: the whole tensor.

        Each tensor is made contiguous and cast to ``dtype``. It is then stored
        in the module's ``_parameters`` if ``model`` declares a parameter with
        that name, otherwise in its ``_buffers``. Finally
        ``model.post_load_weights(quantize)`` is called.

        NOTE: block sizes use floor division, so if a sharded dimension is not
        divisible by ``world_size``, the trailing remainder is loaded on no rank.
        The shape check below is what would surface such a mismatch.

        Args:
            model: Module tree to populate (built under ``init_empty_weights``).
            filenames: ``.safetensors`` files to read.
            quantize: If true, tensors are opened on CPU instead of ``device``.
            device: Target device used when not quantizing.
            dtype: dtype every loaded tensor is cast to.
            rank: This process's rank.
            world_size: Number of ranks the weights are split across.

        Raises:
            ValueError: If a loaded tensor's shape differs from the shape of the
                model parameter with the same name.
        """
        parameters = dict(model.named_parameters())
        open_device = "cpu" if quantize else str(device)

        def shard(slice_, dim: int):
            # Contiguous block of `dim` owned by `rank`; only dims 0 and 1 occur.
            block_size = slice_.get_shape()[dim] // world_size
            start = rank * block_size
            stop = start + block_size
            if dim == 0:
                return slice_[start:stop]
            return slice_[:, start:stop]

        for filename in filenames:
            with safe_open(filename, framework="pt", device=open_device) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_parameter_tensor = parameters.get(name, None)
                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        tensor = shard(slice_, 0)
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            tensor = shard(slice_, 1)
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding) or (
                        name == "embed_out.weight" and model.gpt_neox.tp_embeddings
                    ):
                        tensor = shard(slice_, 0)
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            # Some tensors cannot be sliced with [:]; fall back to
                            # reading them whole. Unlike a bare except, this does
                            # not swallow KeyboardInterrupt/SystemExit.
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    # Replace the placeholder created under init_empty_weights.
                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

        model.post_load_weights(quantize)