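"""Llama model wrappers built on FlashCausalLM for text_generation_server.

`FlashLlama` loads a single-GPU model from `.bin` checkpoints, remapping the stock
Hugging Face weight names onto the fused-projection layout used by
`FlashLlamaForCausalLM`. `FlashLlamaSharded` is the tensor-parallel variant: it
reads `.safetensors` files and each rank loads only its shard of every weight.

Minimal usage sketch (the model id is a placeholder; a CUDA GPU is required, and
the sharded variant assumes torch.distributed is configured, one process per GPU):

    model = FlashLlama("some-org/llama-7b")
    # or, under a distributed launcher:
    model = FlashLlamaSharded("some-org/llama-7b")
"""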
import torch
import torch.distributed

from accelerate import init_empty_weights
from opentelemetry import trace
from pathlib import Path
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.llama import LlamaTokenizer
from typing import Optional, List

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    download_weights,
    weight_hub_files,
    LocalEntryNotFoundError,
)

tracer = trace.get_tracer(__name__)
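

# Single-device variant: loads `.bin` checkpoint files (downloading them from the
# Hub if they are not found locally) and places every tensor on one CUDA device.
# There is no CPU fallback.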
class FlashLlama(FlashCausalLM):
    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        self.past_pad = None
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

        tokenizer = LlamaTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
        )

        # We do not use from_pretrained as we modified the model internal module layout
        try:
            filenames = weight_files(model_id, revision, ".bin")
        # Local files not found
        except LocalEntryNotFoundError:
            hub_files = weight_hub_files(model_id, revision, ".bin")
            filenames = download_weights(hub_files, model_id, revision)

        with init_empty_weights():
            model = FlashLlamaForCausalLM(config)

        self.load_weights(model, filenames, quantize, device, dtype)
        self.model = model.eval().to(device)

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )
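
    # `load_weights` maps the stock Hugging Face checkpoint layout onto the custom
    # module layout of FlashLlamaForCausalLM: the separate q_proj/k_proj/v_proj
    # weights are copied into consecutive slices of a single fused `query_key_value`
    # parameter, and gate_proj/up_proj into a fused `gate_up_proj` parameter. Fused
    # parameters still on the "meta" device (from init_empty_weights) are
    # materialized with `new_empty` first, then filled slice by slice.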
    @staticmethod
    def load_weights(
        model,
        filenames: List[Path],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
    ):
        for filename in filenames:
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                value = value.to(device if not quantize else "cpu").to(dtype)

                layer_name = ".".join(key.split(".")[:4])

                # Fused qkv
                if "q_proj" in key or "k_proj" in key or "v_proj" in key:
                    final_key = layer_name + ".query_key_value.weight"
                # Fused gate and up projs
                elif "gate_proj" in key or "up_proj" in key:
                    final_key = layer_name + ".gate_up_proj.weight"
                else:
                    final_key = key

                module_name, param_name = final_key.rsplit(".", 1)
                module = model.get_submodule(module_name)

                try:
                    current_parameter_tensor = module._parameters[param_name]
                except KeyError:
                    current_parameter_tensor = None

                if current_parameter_tensor is not None:
                    if current_parameter_tensor.device == torch.device("meta"):
                        # Init qkv
                        if "query_key_value" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (value.shape[0] * 3, value.shape[1])
                            )
                        # Init gate and up proj
                        elif "gate_up_proj" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (value.shape[0] * 2, value.shape[1])
                            )

                    # Copy to correct slice
                    if "q_proj" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "k_proj" in key:
                        module._parameters[param_name][
                            value.shape[0] : value.shape[0] * 2
                        ] = value
                    elif "v_proj" in key:
                        module._parameters[param_name][value.shape[0] * 2 :] = value
                    elif "gate_proj" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "up_proj" in key:
                        module._parameters[param_name][value.shape[0] :] = value
                    else:
                        if current_parameter_tensor.shape != value.shape:
                            raise ValueError(
                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
                            )
                        module._parameters[param_name] = value
                else:
                    module._buffers[param_name] = value

                del value

        torch.cuda.empty_cache()
        model.post_load_weights(quantize)
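

# Tensor-parallel variant: one process per GPU (device f"cuda:{rank}"). Weights are
# read from `.safetensors` files via safe_open/get_slice, so each rank slices out
# only its shard instead of loading the full tensors.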
class FlashLlamaSharded(FlashLlama):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.past_pad = None
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

        tokenizer = LlamaTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
            model = FlashLlamaForCausalLM(config, process_group=self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(device)
        torch.distributed.barrier(group=self.process_group)
        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )
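
    # Sharding rules used below: TensorParallelColumnLinear and TensorParallelEmbedding
    # weights are split along dim 0, TensorParallelRowLinear along dim 1, and
    # lm_head.weight along dim 0 when the embeddings are tensor-parallel; everything
    # else (e.g. layer norms) is replicated on every rank. Rank r keeps the contiguous
    # block [r * size // world_size, (r + 1) * size // world_size). For example, with
    # world_size = 2 and a column-parallel weight of shape (8192, 4096), rank 0 keeps
    # rows 0:4096 and rank 1 keeps rows 4096:8192.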
    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    slice_ = f.get_slice(name)

                    layer_name = ".".join(name.split(".")[:4])

                    # Fused qkv
                    if "q_proj" in name or "k_proj" in name or "v_proj" in name:
                        final_name = layer_name + ".query_key_value.weight"
                    # Fused gate and up projs
                    elif "gate_proj" in name or "up_proj" in name:
                        final_name = layer_name + ".gate_up_proj.weight"
                    else:
                        final_name = name

                    module_name, param_name = final_name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        size = slice_.get_shape()[1]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[:, start:stop]
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif name == "lm_head.weight" and model.model.tp_embeddings:
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            tensor = f.get_tensor(name)

                    tensor = tensor.contiguous().to(dtype)

                    try:
                        current_parameter_tensor = module._parameters[param_name]
                    except KeyError:
                        current_parameter_tensor = None

                    if current_parameter_tensor is not None:
                        if current_parameter_tensor.device == torch.device("meta"):
                            # Init qkv
                            if "query_key_value" in final_name:
                                module._parameters[param_name] = tensor.new_empty(
                                    (tensor.shape[0] * 3, tensor.shape[1])
                                )
                            # Init gate and up proj
                            elif "gate_up_proj" in final_name:
                                module._parameters[param_name] = tensor.new_empty(
                                    (tensor.shape[0] * 2, tensor.shape[1])
                                )

                        # Copy to correct slice
                        if "q_proj" in name:
                            module._parameters[param_name][: tensor.shape[0]] = tensor
                        elif "k_proj" in name:
                            module._parameters[param_name][
                                tensor.shape[0] : tensor.shape[0] * 2
                            ] = tensor
                        elif "v_proj" in name:
                            module._parameters[param_name][
                                tensor.shape[0] * 2 :
                            ] = tensor
                        elif "gate_proj" in name:
                            module._parameters[param_name][: tensor.shape[0]] = tensor
                        elif "up_proj" in name:
                            module._parameters[param_name][tensor.shape[0] :] = tensor
                        else:
                            if current_parameter_tensor.shape != tensor.shape:
                                raise ValueError(
                                    f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                                )
                            module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

        torch.cuda.empty_cache()
        model.post_load_weights(quantize)