import re
import torch
import torch.distributed

from typing import List, Optional, Type, Tuple

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedTokenizerBase,
)
from transformers.models.opt.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,
    weight_files,
)
HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception:
    HAS_BITS_AND_BYTES = False
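# NOTE: bitsandbytes is optional; it is only needed when `quantize=True`
# (an ImportError is raised at weight-loading time if it is missing).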

# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# token added to implement a custom sequence tokenization. This token is added at
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
# literally in the source code in case we ever include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].

    Parameters
    ----------
    m : re.Match
        Regex match of a custom sequence such as [START_DNA]...[END_DNA]

    Returns
    ----------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


# END CREDIT
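# Example: escape_custom_split_sequence("[START_DNA]ACGT[END_DNA]") yields
# "[START_DNA]" + SPLIT_MARKER + "A" + SPLIT_MARKER + "C" + ... + SPLIT_MARKER + "T" + SPLIT_MARKER + "[END_DNA]",
# i.e. every character inside the custom sequence gets its own split marker.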


class GalacticaCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        offsets = []
        token_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(escape_custom_split_sequence(r.inputs))
            offsets.append(None)
            token_offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )
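
        # Tokenize the whole batch in one call, truncating to the largest
        # per-request `truncate` value; room for generated tokens is reserved
        # separately via `padding_right_offset`.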
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
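
        # Position ids follow the cumulative attention mask; padding positions
        # are filled with a dummy value of 1 since they are masked out of
        # attention anyway.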
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            offsets=offsets,
            token_offsets=token_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )


class Galactica(OPT):
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Overwrite forward to ignore position_ids"""

        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values


class GalacticaSharded(Galactica):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )
        tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval()
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    if name == "lm_head.weight":
                        continue

                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[name]

                    slice_ = f.get_slice(name)
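
                    # Shard the weight according to the layer type:
                    # column-parallel layers split dim 0 (output features),
                    # row-parallel layers split dim 1 (input features),
                    # embeddings split dim 0 (the vocabulary).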
                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)
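
                    # Int8 quantization path: wrap the weight in a bitsandbytes
                    # Int8Params and patch the module's `linear` to run
                    # bnb.matmul with a shared MatmulLtState.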
                    if quantize:
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)

                    module._parameters[param_name] = tensor
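                    # The embedding weight doubles as the (tied) lm_head weight,
                    # which is why "lm_head.weight" itself was skipped above.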
| if name == "model.decoder.embed_tokens.weight": | |
| model.lm_head._parameters["weight"] = tensor | |

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values