Spaces:
Build error
Build error
| from transformers import AutoTokenizer, EncoderDecoderModel | |
| from transformers import pipeline as hf_pipeline | |
| from pathlib import Path | |
| import spaces | |
| import re | |
| from .app_logger import get_logger | |
class NpcBertGPT2():
    """Report-conclusion generator backed by a fine-tuned BERT→GPT-2 model.

    Wraps a Hugging Face ``EncoderDecoderModel`` and its tokenizer in a
    ``text2text-generation`` pipeline. Call :meth:`load` once before using
    the instance as a callable.
    """

    # Shared application logger (configured in app_logger).
    logger = get_logger()

    # Sentinel phrases marking the start of a spurious repeated tail in the
    # generated text; everything from the match onward is discarded.
    # NOTE(review): the alternatives ("questionable", "anterio", "zius") look
    # like empirically observed generation artifacts — confirm against current
    # model output before changing. Raw string fixes the invalid "\."
    # escape-sequence warning the original non-raw literal produced.
    _TAIL_PATTERN = re.compile(r"\. (questionable|anterio|zius)")

    def __init__(self):
        # Model components are populated lazily by load().
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        # Path is relative to app.py (the working directory at launch).
        self.pretrained_model = "./models/npc-bert-gpt2-best"
        self.logger.info(f"Created {type(self).__name__} instance.")

    def load(self):
        """Loads the fine-tuned EncoderDecoder model and related components.

        This method initializes the model, tokenizer, and pipeline for the
        report conclusion generation task using the pre-trained weights from
        the specified directory.

        Raises:
            FileNotFoundError: If the pretrained model directory is not found.
        """
        if not Path(self.pretrained_model).is_dir():
            # Fixed grammar of the original message ("Cannot found").
            raise FileNotFoundError(
                f"Cannot find pretrained model at: {self.pretrained_model}")

        self.model = EncoderDecoderModel.from_pretrained(self.pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model)
        self.pipeline = hf_pipeline("text2text-generation",
                                    model=self.model,
                                    tokenizer=self.tokenizer,
                                    device_map='auto',
                                    num_beams=4,
                                    do_sample=True,
                                    top_k=5,
                                    temperature=.95,
                                    early_stopping=True,
                                    no_repeat_ngram_size=5,
                                    max_new_tokens=60)

    def __call__(self, *args):
        """Generates a report conclusion for the given input text.

        This method should be called only after the `load` method has been
        executed to ensure that the model and pipeline are properly
        initialized. It forwards its arguments to the Hugging Face
        text2text-generation pipeline.

        Args:
            *args: Variable length argument list passed to the pipeline
                (typically the input report text).

        Returns:
            str: The generated conclusion, truncated at the first known
            spurious tail phrase.

        Raises:
            BrokenPipeError: If the model has not been loaded before calling
                this method.
        """
        if self.pipeline is None:
            msg = "Model was not initialized, have you run load()?"
            raise BrokenPipeError(msg)

        self.logger.info(f"Model: {self.pipeline.model.device = }")
        # The pipeline returns a one-element list; unpack it.
        pipe_out, = self.pipeline(*args)
        pipe_out = pipe_out['generated_text']
        self.logger.info(f"Generated text: {pipe_out}")

        # Hard-coded cleanup: drop everything from the first spurious tail
        # phrase onward, keeping the sentence-terminating period.
        mo = self._TAIL_PATTERN.search(pipe_out)
        if mo is not None:
            end_sig = mo.start()
            pipe_out = pipe_out[:end_sig + 1]
        self.logger.info(f"Displayed text: {pipe_out}")
        return pipe_out