import torch
import torch.distributed

from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from pathlib import Path
from transformers import AutoTokenizer, GPT2Config
from typing import Optional, List

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
    FlashSantacoderForCausalLM,
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    download_weights,
    weight_hub_files,
    LocalEntryNotFoundError,
)

tracer = trace.get_tracer(__name__)
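
# This module provides two flash-attention SantaCoder variants:
# `FlashSantacoder` loads a `.bin` checkpoint onto a single GPU, while
# `FlashSantacoderSharded` splits `.safetensors` weights across the ranks of a
# torch.distributed process group (tensor parallelism). Both remap the
# checkpoint's separate q_attn/kv_attn tensors into the single fused c_attn
# parameter expected by `FlashSantacoderForCausalLM`.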


class FlashSantacoder(FlashCausalLM):
    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        self.past_pad = None
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashSantacoder is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = GPT2Config.from_pretrained(
            model_id,
            revision=revision,
        )

        # We do not use from_pretrained as we modified the model's internal module layout
        try:
            filenames = weight_files(model_id, revision, ".bin")
        # Local files not found
        except LocalEntryNotFoundError:
            hub_files = weight_hub_files(model_id, revision, ".bin")
            filenames = download_weights(hub_files, model_id, revision)

        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config)
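
        # `init_empty_weights` created every parameter on the "meta" device, so
        # nothing is materialized until `load_weights` copies the checkpoint
        # tensors into place below.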
        self.load_weights(
            model,
            filenames,
            quantize,
            device,
            dtype,
            config.architectures[0].startswith("GPT2"),
        )
        self.model = model.eval().to(device)

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            decode_buffer=1,
        )

    @staticmethod
    def load_weights(
        model: FlashSantacoderForCausalLM,
        filenames: List[Path],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        transpose: bool,
    ):
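        # Checkpoints store q_attn and kv_attn separately; both are remapped
        # onto the model's single fused c_attn parameter. When `transpose` is
        # set (GPT2-style checkpoints), Conv1D weights are transposed to match
        # nn.Linear.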
        for filename in filenames:
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                value = value.to(device if not quantize else "cpu").to(dtype)

                layer_name = ".".join(key.split(".")[:4])

                # Fused qkv
                if "q_attn.weight" in key or "kv_attn.weight" in key:
                    final_key = layer_name + ".c_attn.weight"
                elif "q_attn.bias" in key or "kv_attn.bias" in key:
                    final_key = layer_name + ".c_attn.bias"
                else:
                    final_key = key

                module_name, param_name = final_key.rsplit(".", 1)
                module = model.get_submodule(module_name)
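
                # `_parameters` only holds real parameters; a KeyError means the
                # tensor belongs in `_buffers` instead (handled in the else
                # branch below).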
                try:
                    current_parameter_tensor = module._parameters[param_name]
                except KeyError:
                    current_parameter_tensor = None

                if current_parameter_tensor is not None:
                    if transpose and (
                        "c_fc.weight" in key
                        or "c_proj.weight" in key
                        or "q_attn.weight" in key
                        or "kv_attn.weight" in key
                        or "c_attn.weight" in key
                    ):
                        # Transpose as we use nn.Linear instead of Conv1D
                        value = value.T

                    if current_parameter_tensor.device == torch.device("meta"):
                        # Init qkv
                        if "c_attn.weight" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (
                                    model.transformer.head_size
                                    * (model.transformer.num_heads + 2),
                                    value.shape[1],
                                )
                            )
                        elif "c_attn.bias" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (
                                    model.transformer.head_size
                                    * (model.transformer.num_heads + 2)
                                )
                            )
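
                    # Fused layout: the first num_heads * head_size rows hold Q;
                    # the trailing 2 * head_size rows hold the single shared K
                    # and V heads (multi-query attention).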
                    # Copy to correct slice
                    if "q_attn.weight" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "q_attn.bias" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "kv_attn.weight" in key:
                        module._parameters[param_name][
                            model.transformer.head_size * model.transformer.num_heads :
                        ] = value
                    elif "kv_attn.bias" in key:
                        module._parameters[param_name][
                            model.transformer.head_size * model.transformer.num_heads :
                        ] = value
                    else:
                        if current_parameter_tensor.shape != value.shape:
                            raise ValueError(
                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
                            )
                        module._parameters[param_name] = value
                else:
                    module._buffers[param_name] = value

                del value

        torch.cuda.empty_cache()
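        # Finalize the modules (and handle the `quantize` flag) now that every
        # tensor is in place.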
        model.post_load_weights(quantize)

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )


class FlashSantacoderSharded(FlashSantacoder):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.past_pad = None
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = GPT2Config.from_pretrained(
            model_id,
            revision=revision,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
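        # Only `.safetensors` files are used here: `safe_open` lets each rank
        # read just its slice of every tensor instead of materializing the full
        # checkpoint in memory.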

        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config, self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=self.rank,
            world_size=self.world_size,
            transpose=config.architectures[0].startswith("GPT2"),
        )
        self.model = model.eval().to(device)
        torch.distributed.barrier(group=self.process_group)

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
        transpose: bool,
    ):
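        # Sharding scheme: column-parallel layers split their output dimension
        # across ranks, row-parallel layers split their input dimension (with
        # the bias kept on rank 0 only), and embeddings are split along the
        # vocabulary dimension.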
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for key in f.keys():
                    slice_ = f.get_slice(key)

                    layer_name = ".".join(key.split(".")[:4])

                    # Fused qkv
                    if "q_attn.weight" in key or "kv_attn.weight" in key:
                        final_key = layer_name + ".c_attn.weight"
                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
                        final_key = layer_name + ".c_attn.bias"
                    else:
                        final_key = key

                    module_name, param_name = final_key.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    if isinstance(module, TensorParallelColumnLinear):
                        dim = 1 if transpose and "weight" in param_name else 0
                        size = slice_.get_shape()[dim]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = (
                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
                        )
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            dim = 0 if transpose else 1
                            size = slice_.get_shape()[dim]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = (
                                slice_[start:stop]
                                if dim == 0
                                else slice_[:, start:stop]
                            )
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for RowLinear to add the bias only once:
                            # partial outputs are summed across ranks, so only
                            # rank 0 keeps the real bias and the other ranks use
                            # zeros to avoid adding it world_size times.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
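                    # When the embeddings are tensor-parallel, the lm_head is
                    # sharded along the vocabulary dimension as well.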
| elif key == "lm_head.weight" and model.transformer.tp_embeddings: | |
| size = slice_.get_shape()[0] | |
| block_size = size // world_size | |
| start = rank * block_size | |
| stop = (rank + 1) * block_size | |
| tensor = slice_[start:stop] | |
| else: | |
| try: | |
| tensor = slice_[:] | |
| except: | |
| tensor = f.get_tensor(key) | |

                    tensor = tensor.contiguous().to(dtype)

                    try:
                        current_parameter_tensor = module._parameters[param_name]
                    except KeyError:
                        current_parameter_tensor = None

                    if current_parameter_tensor is not None:
                        if transpose and (
                            "c_fc.weight" in key
                            or "c_proj.weight" in key
                            or "q_attn.weight" in key
                            or "kv_attn.weight" in key
                            or "c_attn.weight" in key
                        ):
                            # Transpose as we use nn.Linear instead of Conv1D
                            tensor = tensor.T

                        if current_parameter_tensor.device == torch.device("meta"):
                            # Init qkv
                            if "c_attn.weight" in final_key:
                                module._parameters[param_name] = tensor.new_empty(
                                    (
                                        model.transformer.head_size
                                        * (model.transformer.num_heads + 2),
                                        tensor.shape[1],
                                    )
                                )
                            elif "c_attn.bias" in final_key:
                                module._parameters[param_name] = tensor.new_empty(
                                    (
                                        model.transformer.head_size
                                        * (model.transformer.num_heads + 2)
                                    )
                                )

                        # Copy to correct slice
                        if "q_attn" in key:
                            size = tensor.shape[0]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = tensor[start:stop]
                            module._parameters[param_name][: tensor.shape[0]] = tensor
                        elif "kv_attn.weight" in key:
                            module._parameters[param_name][
                                model.transformer.head_size
                                * model.transformer.num_heads :
                            ] = tensor
                        elif "kv_attn.bias" in key:
                            module._parameters[param_name][
                                model.transformer.head_size
                                * model.transformer.num_heads :
                            ] = tensor
                        elif "c_attn" in key:
                            # Slice q_tensor by shard
                            q_tensor = tensor[: -2 * model.transformer.head_size]
                            block_size = q_tensor.shape[0] // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            q_tensor = q_tensor[start:stop]
                            module._parameters[param_name][
                                : q_tensor.shape[0]
                            ] = q_tensor
                            # Kv tensor is copied for every shard
                            kv_tensor = tensor[-2 * model.transformer.head_size :]
                            module._parameters[param_name][
                                q_tensor.shape[0] :
                            ] = kv_tensor
                        else:
                            if current_parameter_tensor.shape != tensor.shape:
                                raise ValueError(
                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                                )
                            module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

        torch.cuda.empty_cache()
        model.post_load_weights(quantize)
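

# Illustrative usage sketch only: these classes are normally constructed by the
# serving stack rather than directly, and "bigcode/santacoder" stands in for
# any compatible checkpoint.
#
#     model = FlashSantacoder("bigcode/santacoder")         # single GPU
#     model = FlashSantacoderSharded("bigcode/santacoder")  # one process per rank,
#                                                           # e.g. under torchrun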