Spaces:
Running
on
Zero
Running
on
Zero
| from models.model_manager import ModelManager | |
| import torch | |
def tokenize_long_prompt(tokenizer, prompt, max_length=None):
    """Tokenize ``prompt`` and fold the ids into rows of a fixed window size.

    The prompt is tokenized once without a length limit to measure it. It is
    then tokenized again, padded/truncated to the next multiple of the window
    size. Finally it is reshaped so that each row holds one window of ids.

    Args:
        tokenizer: Object exposing a writable ``model_max_length`` attribute.
            It must be callable as ``tokenizer(prompt, return_tensors="pt",
            ...)`` and return an object with an ``input_ids`` tensor
            (presumably a Hugging Face tokenizer -- TODO confirm).
        prompt: Text to tokenize. The final reshape folds all ids into rows of
            ``length``, so this is assumed to be a single string (a batch of
            one) -- TODO confirm callers never pass lists.
        max_length: Window size per row. Defaults to
            ``tokenizer.model_max_length``.

    Returns:
        A tensor of shape ``(num_windows, length)`` with ``num_windows >= 1``.

    Raises:
        ValueError: If the effective window size is not positive.
    """
    original_max_length = tokenizer.model_max_length
    length = original_max_length if max_length is None else max_length
    if length <= 0:
        # Validate before touching the tokenizer so no state is mutated.
        raise ValueError(f"max_length must be positive, got {length}")

    # Temporarily lift the tokenizer's limit for the measuring pass, to avoid
    # the tokenizer's length warning. Always restore the caller's ORIGINAL
    # value, even when max_length was overridden or tokenization raises.
    tokenizer.model_max_length = 99999999
    try:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    finally:
        tokenizer.model_max_length = original_max_length

    # Round the token count up to a whole number of windows. Keep at least
    # one window so an empty tokenization still yields a (1, length) tensor.
    num_windows = max(1, (input_ids.shape[1] + length - 1) // length)
    padded_length = num_windows * length

    # Tokenize again, padded/truncated to exactly padded_length tokens.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=padded_length,
        truncation=True,
    ).input_ids

    # One row per window, to fit the text encoder.
    num_rows = input_ids.shape[1] // length
    return input_ids.reshape((num_rows, length))
class BasePrompter:
    """Base class holding a pipeline of prompt refiners and prompt extenders.

    Refiners are callables applied in load order to a prompt by
    :meth:`process_prompt`. Extenders are callables applied in load order to
    a dict (starting as ``{"prompt": prompt}``) by :meth:`extend_prompt`.
    """

    def __init__(self):
        # Populated by load_prompt_refiners / load_prompt_extenders.
        self.refiners = []
        self.extenders = []

    def load_prompt_refiners(
        self, model_manager: "ModelManager", refiner_classes=()
    ):
        """Instantiate each refiner class from ``model_manager`` and append it.

        Args:
            model_manager: Passed unchanged to each class's
                ``from_model_manager`` constructor.
            refiner_classes: Iterable of classes exposing
                ``from_model_manager(model_manager)``. Defaults to an empty
                tuple (immutable, so nothing is shared between calls).
        """
        for refiner_class in refiner_classes:
            self.refiners.append(refiner_class.from_model_manager(model_manager))

    def load_prompt_extenders(
        self, model_manager: "ModelManager", extender_classes=()
    ):
        """Instantiate each extender class from ``model_manager`` and append it.

        Args:
            model_manager: Passed unchanged to each class's
                ``from_model_manager`` constructor.
            extender_classes: Iterable of classes exposing
                ``from_model_manager(model_manager)``. Defaults to an empty
                tuple (immutable, so nothing is shared between calls).
        """
        for extender_class in extender_classes:
            self.extenders.append(extender_class.from_model_manager(model_manager))

    def process_prompt(self, prompt, positive=True):
        """Run ``prompt`` through every loaded refiner, in load order.

        Args:
            prompt: A prompt, or a (possibly nested) list of prompts. Lists are
                processed element-wise and returned as new lists.
            positive: Forwarded to each refiner as the ``positive`` keyword.

        Returns:
            The refined prompt, or a list mirroring the input's nesting.
        """
        if isinstance(prompt, list):
            return [self.process_prompt(item, positive=positive) for item in prompt]
        for refiner in self.refiners:
            prompt = refiner(prompt, positive=positive)
        return prompt

    def extend_prompt(self, prompt: str, positive=True):
        """Build ``{"prompt": prompt}`` and pass it through every extender.

        Each extender receives the dict returned by the previous one.

        Args:
            prompt: Prompt text stored under the ``"prompt"`` key.
            positive: Currently unused (not forwarded to extenders); kept for
                signature symmetry with :meth:`process_prompt`.

        Returns:
            The dict returned by the last extender, or the initial dict when
            no extenders are loaded.
        """
        extended_prompt = {"prompt": prompt}
        for extender in self.extenders:
            extended_prompt = extender(extended_prompt)
        return extended_prompt