import torch
import torch.distributed

from typing import List, Optional

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
)
from transformers.models.gpt_neox.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation_server.models import CausalLM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception:
    HAS_BITS_AND_BYTES = False
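

# --- Illustrative sketch, not part of the original file ---------------------
# load_weights below shards each tensor-parallel weight by giving every rank
# one contiguous block along the sharded dimension. The helper below repeats
# that arithmetic in isolation (the name _shard_bounds is made up for
# illustration); like the real code, it assumes `size` is divisible by
# `world_size`.
def _shard_bounds(size: int, rank: int, world_size: int):
    """Return the [start, stop) range owned by `rank` when `size` is split
    evenly into `world_size` contiguous blocks."""
    block_size = size // world_size
    return rank * block_size, (rank + 1) * block_size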


class GPTNeoxSharded(CausalLM):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Instantiate the model on the meta device so no memory is allocated
        # before the sharded weights are loaded below.
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval()
        torch.distributed.barrier(group=self.process_group)
        # Note: super(CausalLM, self) skips CausalLM.__init__ and calls its
        # base class initializer directly.
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        # Column-parallel linear: shard the output dimension (dim 0).
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            # Row-parallel linear: shard the input dimension (dim 1).
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for RowLinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    if quantize:
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        if self.model.gpt_neox.tp_embeddings:
            outputs = self.model.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )

            # Logits are sharded over the vocabulary dimension, so gather them
            # from all ranks before returning.
            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
            torch.distributed.all_gather(
                logits, outputs.logits, group=self.process_group
            )
            logits = torch.cat(logits, dim=2)

            return logits, outputs.past_key_values
        # While the model itself is sharded, the embeddings might not be, as
        # their size might not be divisible by the number of shards.
        else:
            return super(GPTNeoxSharded, self).forward(
                input_ids, attention_mask, position_ids, past_key_values
            )
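

# --- Hypothetical usage sketch, not part of the original file ---------------
# GPTNeoxSharded expects one process per GPU: initialize_torch_distributed()
# builds the process group from the usual RANK / WORLD_SIZE environment
# variables, so the module would typically be launched with something like
# `torchrun --nproc_per_node=2 <this_file>.py`. The model id below is only an
# example of a GPT-NeoX checkpoint with safetensors weights; in practice this
# class is constructed by text_generation_server's model factory rather than
# directly.
if __name__ == "__main__":
    sharded = GPTNeoxSharded("EleutherAI/gpt-neox-20b", quantize=False)
    print(f"rank {sharded.rank}/{sharded.world_size} ready (master={sharded.master})")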