import re

import torch

from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
    PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional

from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor

class Sampling:
    """Stochastic token selection: draws from the softmax distribution,
    using a per-request seeded generator for reproducibility."""

    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, -1)
        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
        return next_tokens


class Greedy:
    """Deterministic token selection: always picks the highest-scoring token."""

    def __call__(self, logits):
        return logits.argmax()

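# --- Illustrative sketch (not part of the original file) --------------------
# A minimal demo contrasting the two strategies above, assuming a dummy
# vocabulary of 100 tokens; `_demo_token_choice` is a hypothetical helper.
def _demo_token_choice():
    logits = torch.randn(1, 100)  # (batch=1, vocab=100) dummy scores
    sampled = Sampling(seed=42)(logits)  # stochastic, reproducible via the seed
    greedy = Greedy()(logits[-1])  # deterministic argmax over the last row
    return sampled, greedy
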
class NextTokenChooser:
    def __init__(
        self,
        watermark=False,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        sampling = do_sample

        if watermark:
            warpers.append(WatermarkLogitsProcessor(device=device))
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()
    def __call__(self, input_ids, scores):
        # Warp logits
        if scores.shape[0] > 1:
            # only warp the last token logits
            scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
        else:
            scores = self.warpers(input_ids, scores)

        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
        next_id = self.choice(scores[-1])

        return next_id.view(1, 1), logprobs
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )

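# --- Illustrative sketch (not part of the original file) --------------------
# A minimal, hedged usage example of NextTokenChooser; the parameter values
# and vocabulary size are made up, and `_demo_next_token_chooser` is a
# hypothetical helper.
def _demo_next_token_chooser():
    chooser = NextTokenChooser(temperature=0.7, top_k=50, top_p=0.95, seed=0)
    input_ids = torch.tensor([[1, 2, 3]])  # (batch=1, seq_len=3) prompt ids
    scores = torch.randn(1, 100)  # logits for the last position only
    next_id, logprobs = chooser(input_ids, scores)  # next_id has shape (1, 1)
    return next_id, logprobs
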
class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        stop_sequence = re.escape(stop_sequence)
        self.regex = re.compile(f".*{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False

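# --- Illustrative sketch (not part of the original file) --------------------
# The regex above only fires when the accumulated output *ends* with the stop
# sequence; `_demo_stop_sequence` is a hypothetical helper showing that.
def _demo_stop_sequence():
    criteria = StopSequenceCriteria("###")
    assert not criteria("### in the middle")  # no match: not at the end
    assert criteria("generated text ###")  # match: output ends with "###"
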
class StoppingCriteria:
    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""
        self.ignore_eos_token = ignore_eos_token

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if not self.ignore_eos_token and last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return StoppingCriteria(
            tokenizer.eos_token_id,
            stop_sequence_criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
        )

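# --- Illustrative sketch (not part of the original file) --------------------
# Drives the stopping logic token by token; the ids and decoded strings are
# made up, and `_demo_stopping_criteria` is a hypothetical helper.
def _demo_stopping_criteria():
    criteria = StoppingCriteria(
        eos_token_id=2,
        stop_sequence_criterias=[StopSequenceCriteria("###")],
        max_new_tokens=8,
    )
    for token_id, text in [(10, "Hello"), (11, " world"), (12, "###")]:
        stop, reason = criteria(token_id, text)
        if stop:
            return reason  # FinishReason.FINISH_REASON_STOP_SEQUENCE here
    return None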