| """ | |
| converse.py - this script has functions for handling the conversation between the user and the bot. | |
| https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size | |
| """ | |
| import logging | |
| logging.basicConfig( | |
| level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" | |
| ) | |
| import pprint as pp | |
| import time | |
| from grammar_improve import remove_trailing_punctuation | |
| from constrained_generation import constrained_generation | |
def discussion(
    prompt_text: str,
    speaker: str,
    responder: str,
    pipeline,
    timeout=45,
    min_length=8,
    max_length=64,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    full_text=False,
    length_penalty=0.8,
    no_repeat_ngram_size=2,
    num_return_sequences=1,
    device=-1,
    verbose=False,
    constrained_beam_search=False,
):
| """ | |
| discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot. | |
| Parameters | |
| ---------- | |
| prompt_text : str, the prompt to ask the bot, usually the user's question | |
| speaker : str, the name of the person who is speaking the prompt | |
| responder : str, the name of the person who is responding to the prompt | |
| pipeline : transformers.Pipeline, the pipeline to use for generating the response | |
| timeout : int, optional, the number of seconds to wait before timing out, by default 45 | |
| max_length : int, optional, the maximum number of tokens to generate, defaults to 128 | |
| top_p : float, optional, the top probability to use for sampling, defaults to 0.95 | |
| top_k : int, optional, the top k to use for sampling, defaults to 50 | |
| temperature : float, optional, the temperature to use for sampling, defaults to 0.7 | |
| full_text : bool, optional, whether to return the full text or just the generated text, defaults to False | |
| num_return_sequences : int, optional, the number of sequences to return, defaults to 1 | |
| device : int, optional, the device to use for generation, defaults to -1 (CPU) | |
| verbose : bool, optional, whether to print the generated text, defaults to False | |
| Returns | |
| ------- | |
| str, the generated text | |
| """ | |
    logging.debug(f"input args: {locals()}")
    p_list = []  # track conversation
    p_list.append(speaker.lower() + ":\n")
    p_list.append(prompt_text.lower() + "\n")
    p_list.append("\n")
    p_list.append(responder.lower() + ":\n")
    this_prompt = "".join(p_list)
    if verbose:
        print("overall prompt:\n")
        pp.pprint(this_prompt, indent=4)
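    # Illustration of the prompt format built above (names are whatever the
    # caller passes; "person alpha"/"person beta" here are hypothetical):
    # for speaker="person alpha", prompt_text="Hi!", responder="person beta",
    # this_prompt is "person alpha:\nhi!\n\nperson beta:\n"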
    if constrained_beam_search:
        logging.info("generating using constrained beam search ...")
        response = constrained_generation(
            prompt=this_prompt,
            pipeline=pipeline,
            min_generated_tokens=min_length,
            max_generated_tokens=max_length,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            repetition_penalty=1.0,
            num_beams=4,
            timeout=timeout,
            verbose=False,
            full_text=full_text,
            speaker_name=speaker,
            responder_name=responder,
        )
        bot_dialogue = consolidate_texts(
            name_resp=responder,
            model_resp=response.split("\n"),
            name_spk=speaker,
            verbose=verbose,
            print_debug=True,
        )
    else:
        logging.info("generating using sampling ...")
        bot_dialogue = gen_response(
            this_prompt,
            pipeline,
            speaker,
            responder,
            timeout=timeout,
            min_length=min_length,
            max_length=max_length,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            full_text=full_text,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            num_return_sequences=num_return_sequences,
            device=device,
            verbose=verbose,
        )
    logging.debug(f"generation done. bot_dialogue: {bot_dialogue}")
    if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1:
        bot_resp = ", ".join(bot_dialogue)
    elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1:
        bot_resp = bot_dialogue[0]
    else:
        bot_resp = bot_dialogue
    bot_resp = " ".join(bot_resp) if isinstance(bot_resp, list) else bot_resp
    bot_resp = bot_resp.strip()
    # remove trailing ',' and '.' characters
    bot_resp = remove_trailing_punctuation(bot_resp)
    if verbose:
        print("\nfinished!")
        print("\n... bot response:\n")
        pp.pprint(bot_resp)
    p_list.append(bot_resp + "\n")
    p_list.append("\n")
    logging.info(f"finished generating response:\n\t{bot_resp}")
    # return the bot response and the full conversation
    return {"out_text": bot_resp, "full_conv": p_list}
def gen_response(
    query: str,
    pipeline,
    speaker: str,
    responder: str,
    timeout=45,
    min_length=12,
    max_length=48,
    top_p=0.95,
    top_k=20,
    temperature=0.5,
    full_text=False,
    num_return_sequences=1,
    length_penalty: float = 0.8,
    repetition_penalty: float = 3.5,
    no_repeat_ngram_size=2,
    device=-1,
    verbose=False,
    **kwargs,
):
| """ | |
| gen_response - a function that takes in a prompt and generates a response using the pipeline. This operates underneath the discussion function. | |
| Parameters | |
| ---------- | |
| query : str, the prompt to ask the bot, usually the user's question | |
| speaker : str, the name of the person who is speaking the prompt | |
| responder : str, the name of the person who is responding to the prompt | |
| pipeline : transformers.Pipeline, the pipeline to use for generating the response | |
| timeout : int, optional, the number of seconds to wait before timing out, by default 45 | |
| min_length : int, optional, the minimum number of tokens to generate, defaults to 4 | |
| max_length : int, optional, the maximum number of tokens to generate, defaults to 64 | |
| top_p : float, optional, the top probability to use for sampling, defaults to 0.95 | |
| top_k : int, optional, the top k to use for sampling, defaults to 50 | |
| temperature : float, optional, the temperature to use for sampling, defaults to 0.7 | |
| full_text : bool, optional, whether to return the full text or just the generated text, defaults to False | |
| num_return_sequences : int, optional, the number of sequences to return, defaults to 1 | |
| device : int, optional, the device to use for generation, defaults to -1 (CPU) | |
| verbose : bool, optional, whether to print the generated text, defaults to False | |
| Returns | |
| ------- | |
| str, the generated text | |
| """ | |
    logging.debug(f"input args - gen_response() : {locals()}")
    input_len = len(pipeline.tokenizer(query).input_ids)
    if max_length + input_len > 1024:
        # clamp so prompt + generated tokens stay within the context window
        max_length = max(1024 - input_len, 8)
        print(f"max_length too large, setting to {max_length}")
    st = time.perf_counter()
    response = pipeline(
        query,
        min_length=min_length + input_len,
        max_length=max_length + input_len,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
        max_time=timeout,
        return_full_text=full_text,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        clean_up_tokenization_spaces=True,
        remove_invalid_values=True,
        **kwargs,
    )  # sampling - the likely better beam-less method
    rt = round(time.perf_counter() - st, 2)
    if verbose:
        print(f"took {rt} sec to respond")
        print("\n[DEBUG] generated:\n")
        pp.pprint(response)  # for debugging
    # process the full result to get the ~bot response~ piece
    this_result = str(response[0]["generated_text"]).split(
        "\n"
    )  # TODO: adjust hardcoded index to be dynamic (if num_return_sequences > 1)
    bot_dialogue = consolidate_texts(
        name_resp=responder,
        model_resp=this_result,
        name_spk=speaker,
        verbose=verbose,
        print_debug=True,
    )
    if verbose:
        print(f"DEBUG: {bot_dialogue} was original response pre-SC")
    return bot_dialogue
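
# Note on the length arithmetic above (explanatory, based on how transformers
# text-generation pipelines count tokens): `max_length` covers the prompt plus
# the generated tokens, which is why `input_len` is added back in. For example,
# a 100-token prompt with max_length=48 calls the pipeline with max_length=148,
# i.e. at most 48 new tokens; the 1024 cap matches GPT-2-family context windows.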
def consolidate_texts(
    model_resp: list,
    name_resp: str = None,
    name_spk: str = None,
    verbose=False,
    print_debug=False,
):
    """
    consolidate_texts - given a list of speaker-name and speaker-text lines,
    return the consecutive lines belonging to the responder, stopping as soon
    as another speaker appears.

    Parameters:
        model_resp (list): the list of strings to consolidate (usually from the model)
        name_resp (str): the name of the person who is responding
        name_spk (str): the name of the person who is speaking
        verbose (bool): whether to print the results
        print_debug (bool): whether to print debug info during looping

    Returns:
        list (or str, if model_resp has a single element): the consecutive
        messages belonging to the responder
    """
    assert len(model_resp) > 0, "model_resp is empty"
    if len(model_resp) == 1:
        return model_resp[0]
    name_resp = "person beta" if name_resp is None else name_resp
    name_spk = "person alpha" if name_spk is None else name_spk
    if verbose:
        print("====" * 10)
        print(
            f"\n[DEBUG] initial model_resp has {len(model_resp)} lines: \n\t{model_resp}"
        )
        print(
            f" the first element is \n\t{model_resp[0]} and it is {type(model_resp[0])}"
        )
    fn_resp = []
    name_counter = 0
    for resline in model_resp:
        if name_resp.lower() in resline.lower():
            # the line is the bot's own name tag; count it but don't keep it
            name_counter += 1
            continue
        if name_spk.lower() in resline.lower():
            if print_debug:
                print(f"\nDEBUG: \n\t{resline}\ncaused the break")
            break  # the speaker's name is in the line, so we're done
        if ": " in resline and name_resp.lower() not in resline.lower():
            # a different speaker tag appeared, so stop here
            if print_debug:
                print(f"\nDEBUG: \n\t{resline}\ncaused the break")
            break
        fn_resp.append(resline)
    if verbose:
        print("--" * 10)
        print("\nthe full response is:\n")
        print("\n".join(fn_resp))
        print("--" * 10)
    return fn_resp
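
# Worked example (a sketch with hypothetical lines): given a raw model
# continuation that drifts into the other speaker's turn,
#
#   lines = [
#       "i like to go hiking.",
#       "do you?",
#       "person alpha:",
#       "sometimes.",
#   ]
#   consolidate_texts(model_resp=lines, name_resp="person beta",
#                     name_spk="person alpha")
#
# returns ["i like to go hiking.", "do you?"]: the loop stops at the first
# line containing the speaker's name, dropping the model's hallucinated
# continuation of the conversation.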