| """ | |
| constrained_generation.py - use constrained beam search to generate text from a model with entered constraints | |
| """ | |
| import copy | |
| import logging | |
| logging.basicConfig(level=logging.INFO) | |
| import time | |
| from pathlib import Path | |
| import yake | |
| from transformers import AutoTokenizer, PhrasalConstraint | |


def get_tokenizer(model_name="gpt2", verbose=False):
    """
    get_tokenizer - load and return a tokenizer for the given model

    :param model_name: name of the model to use, default "gpt2"
    :param verbose: print a confirmation once the tokenizer is loaded
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, add_special_tokens=False, padding=True, truncation=True
    )
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    if verbose:
        print(f"loaded tokenizer {model_name}")
    return tokenizer
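
# Usage sketch (assumes the "gpt2" checkpoint is available from the Hugging Face hub;
# any causal-LM tokenizer name works the same way):
#   tok = get_tokenizer("gpt2", verbose=True)
#   ids = tok("hello there").input_ids  # list of token ids for the model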


def unique_words(list_of_strings):
    """
    unique_words - keep only the strings whose words have not appeared in any
    earlier string. Returns the filtered list, preserving the original order.
    """
    seen_words = []
    output_list = []
    for string in list_of_strings:
        # split string into words
        words = string.split()
        # the string is kept only if every one of its words is new
        unique_status = True
        for word in words:
            if word not in seen_words:
                seen_words.append(word)
            else:
                unique_status = False
                break
        if unique_status:
            output_list.append(string)
    return output_list
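
# Example: a later string that repeats an already-seen word is dropped, e.g.
#   unique_words(["big cat", "small dog", "cat nap"]) -> ["big cat", "small dog"]
# ("cat nap" is filtered out because "cat" already appeared in "big cat").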


def create_kw_extractor(
    language="en",
    max_ngram_size=3,
    deduplication_algo="seqm",
    windowSize=10,
    numOfKeywords=10,
    ddpt=0.7,
):
    """
    create_kw_extractor - create a yake.KeywordExtractor with the given settings

    :param language: language of the text
    :param max_ngram_size: maximum ngram size for a keyphrase
    :param deduplication_algo: deduplication algorithm used by yake (default "seqm")
    :param windowSize: window size
    :param numOfKeywords: number of keywords to return
    :param ddpt: deduplication percentage threshold, between 0 and 1
    :return: keyword extractor object
    """
    assert 0 <= ddpt <= 1, f"need 0 <= ddpt <= 1, got {ddpt}"
    return yake.KeywordExtractor(
        lan=language,
        n=max_ngram_size,
        dedupLim=ddpt,
        dedupFunc=deduplication_algo,
        windowsSize=windowSize,  # yake spells this kwarg "windowsSize"
        top=numOfKeywords,
        features=None,
    )
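
# Direct-use sketch (yake scores are "lower is better", so the first tuples are
# the most relevant):
#   extractor = create_kw_extractor(max_ngram_size=2, numOfKeywords=5)
#   pairs = extractor.extract_keywords("long body of text ...")  # [(keyword, score), ...]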


def simple_kw(body_text: str, yake_ex=None, max_kw=15, verbose=False):
    """
    simple_kw - extract keywords from a text using yake

    Args:
        body_text (str): text to extract keywords from
        yake_ex (yake.KeywordExtractor, optional): yake keyword extractor. Defaults to None.
        max_kw (int, optional): maximum number of keywords to extract. Defaults to 15.
        verbose (bool, optional): print the keywords as well as logging them. Defaults to False.

    Returns:
        list: list of keywords
    """
    yake_ex = yake_ex or create_kw_extractor(
        max_ngram_size=2,
        ddpt=0.9,
        windowSize=10,
        deduplication_algo="seqm",
        numOfKeywords=max_kw,
    )  # settings per optuna study
    keywords = yake_ex.extract_keywords(body_text)
    keywords_list = [str(kw[0]).lower() for kw in keywords]
    logging.info(
        f"YAKE: found {len(keywords_list)} keywords, the top {max_kw} are: {keywords_list[:max_kw]}"
    )
    if verbose:
        print(f"found {len(keywords_list)} keywords, the top {max_kw} are:")
        print(keywords_list[:max_kw])
    return keywords_list[:max_kw]
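
# Example (illustrative text; returns up to `max_kw` lowercased keyphrases,
# most relevant first):
#   simple_kw("The quick brown fox jumps over the lazy dog near the river bank.")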


def constrained_generation(
    prompt: str,
    pipeline,
    tokenizer=None,
    no_repeat_ngram_size=2,
    length_penalty=0.7,
    repetition_penalty=3.5,
    num_beams=4,
    max_generated_tokens=48,
    min_generated_tokens=2,
    timeout=300,
    num_return_sequences=1,
    verbose=False,
    full_text=False,
    force_word: str = None,
    speaker_name: str = "Person Alpha",
    responder_name: str = "Person Beta",
    **kwargs,
):
    """
    constrained_generation - generate text based on a prompt and keyword constraints

    USAGE
    -----
    response = constrained_generation("hey man - how have you been lately?",
                                      my_chatbot, verbose=True,
                                      force_word=" meme", num_beams=32)

    Parameters
    ----------
    prompt : str, prompt to use for generation
    pipeline : transformers.pipeline, pipeline to use, must be compatible with tokenizer & text2text model
    tokenizer : transformers.PreTrainedTokenizer, optional, tokenizer to use, must be compatible
                with the model; defaults to a copy of the pipeline's own tokenizer
    no_repeat_ngram_size : int, optional, default=2
    length_penalty : float, optional, default=0.7
    repetition_penalty : float, optional, default=3.5
    num_beams : int, optional, default=4
    max_generated_tokens : int, optional, default=48
    min_generated_tokens : int, optional, default=2
    timeout : int, optional, default=300, maximum generation time in seconds
    num_return_sequences : int, optional, default=1
    verbose : bool, optional, default=False, print output
    full_text : bool, optional, default=False, return the prompt together with the generated text
    force_word : str, optional, default=None, word forced to appear in the generation
    speaker_name : str, optional, default="Person Alpha", name of speaker
    responder_name : str, optional, default="Person Beta", name of responder

    Returns
    -------
    response : str, generated text
    """
    logging.debug(f"constraining generation with {locals()}")
    st = time.perf_counter()
    tokenizer = tokenizer or copy.deepcopy(pipeline.tokenizer)
    tokenizer.add_prefix_space = True
    tokenizer.add_special_tokens = False
    prompt_length = len(tokenizer(prompt, truncation=True).input_ids)
    if responder_name.lower() not in prompt.lower():
        prompt = f"{prompt}\n\n{responder_name}:\n"
    # key_prompt_phrases = get_keyberts(prompt)
    key_prompt_phrases = simple_kw(prompt)
    try:
        responder_name_words = responder_name.lower().split()
        speaker_name_words = speaker_name.lower().split()
    except Exception as e:
        responder_name_words = []
        speaker_name_words = []
        logging.info(f"could not split names: {e}")
    # drop keyphrases that contain either participant's name
    key_prompt_phrases = [
        p
        for p in key_prompt_phrases
        if not any(name in p for name in responder_name_words)
        and not any(name in p for name in speaker_name_words)
    ]
    force_flexible = unique_words(key_prompt_phrases)
    logging.info(f"found the following keywords: {force_flexible}")
    if verbose:
        print(f"found keywords: {force_flexible}")
        if force_word is not None:
            logging.info(f"forcing the word: {force_word}")
    if len(force_flexible) == 0:
        force_flexible = None
    constraints = (
        [
            PhrasalConstraint(
                tokenizer(force_word, add_special_tokens=False).input_ids,
            ),
        ]
        if force_word is not None
        else None
    )
    # nested list -> a disjunctive constraint: at least one keyword must appear
    force_words_ids = (
        [
            tokenizer(
                force_flexible,
                add_special_tokens=False,
            ).input_ids,
        ]
        if force_flexible is not None
        else None
    )
    try:
        logging.info("generating text..")
        result = pipeline(
            prompt,
            constraints=constraints,
            force_words_ids=force_words_ids,
            max_length=None,
            max_new_tokens=max_generated_tokens,
            min_length=(
                min_generated_tokens + prompt_length
                if full_text
                else min_generated_tokens
            ),
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            max_time=timeout,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            return_full_text=full_text,
            clean_up_tokenization_spaces=True,
            early_stopping=True,
            do_sample=False,  # constrained beam search does not support sampling
            **kwargs,
        )
        response = result[0]["generated_text"]
        rt = round((time.perf_counter() - st) / 60, 3)
        logging.info(f"generated response in {rt} minutes")
        if verbose:
            print(f"input prompt:\n\t{prompt}")
            print(f"response:\n\t{response}")
    except Exception as e:
        logging.error(f"could not generate response: {e}")
        response = "Sorry, I don't know how to respond to that."
    return response
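

if __name__ == "__main__":
    # Minimal smoke test - a sketch only. It assumes the small "gpt2" checkpoint
    # and the transformers "text-generation" pipeline; the module itself does not
    # require any particular model.
    from transformers import pipeline as hf_pipeline

    chatbot = hf_pipeline("text-generation", model="gpt2")
    reply = constrained_generation(
        "Person Alpha:\nhey, how have you been lately?",
        chatbot,
        num_beams=8,
        max_generated_tokens=32,
        verbose=True,
    )
    print(reply)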