| """ | |
| ai_single_response.py | |
| An executable way to call the model. example: | |
| *\gpt2_chatbot> python .\ai_single_response.py --prompt "where is the grocery store?" --time | |
| extended-summary: | |
| A system and method for interacting with a virtual machine using a series of messages , each message having associated otherwise one or more actions to be taken by the machine. The speaker participates in a chat with a responder , and the response from the responder is returned. | |
| """ | |

import argparse
import pprint as pp
import time
import warnings
from datetime import datetime
from pathlib import Path

from cleantext import clean

# suppress the noisy gradient_checkpointing warning before importing aitextgen
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")

from aitextgen import aitextgen


def query_gpt_model(
    folder_path,
    prompt_msg: str,
    speaker=None,
    responder="person beta",
    kparam=150,
    temp=0.75,
    top_p=0.65,
    verbose=False,
    use_gpu=False,
):
| """ | |
| query_gpt_model [pass a prompt in to model, get a response. Does NOT "remember" past conversation] | |
| Args: | |
| folder_path ([type]): [description] | |
| prompt_msg (str): [description] | |
| speaker ([type], optional): [description]. Defaults to None. | |
| responder (str, optional): [description]. Defaults to "person beta". | |
| kparam (int, optional): [description]. Defaults to 125. | |
| temp (float, optional): [description]. Defaults to 0.75. | |
| top_p (float, optional): [description]. Defaults to 0.65. | |
| verbose (bool, optional): [description]. Defaults to False. | |
| use_gpu (bool, optional): [description]. Defaults to False. | |
| Returns: | |
| [dict]: [returns a dict with A) just model response as str B) total conversation] | |
| """ | |
    ai = aitextgen(
        model="r3dhummingbird/DialoGPT-medium-joshua",
        # model_folder=folder_path,  # NOTE: folder_path is not currently used to load the model
        to_gpu=use_gpu,  # was hardcoded to False, which silently ignored the use_gpu argument
    )
    print("loaded model")
    p_list = []  # lines of the prompt, joined into a single string below
    # manual corrections to speaker/responder based on the model in use
    if "natqa" in str(folder_path).lower():
        speaker = "person alpha"
        responder = "person beta"
    if "wow" in str(folder_path).lower():
        speaker = "person alpha"
        responder = "person beta"
    if "peter" in str(folder_path).lower():
        speaker = None
        responder = "peter szemraj"
    if speaker is not None:
        p_list.append(speaker.lower() + ":" + "\n")  # write prompt as the speaker
    p_list.append(prompt_msg.lower() + "\n")
    p_list.append("\n")
    p_list.append(responder.lower() + ":" + "\n")
    this_prompt = "".join(p_list)
    if verbose:
        print("overall prompt:\n")
        pp.pprint(this_prompt, indent=4)
    print("\n... generating... \n")
    this_result = ai.generate(
        n=1,
        top_k=kparam,
        batch_size=512,
        max_length=128,
        min_length=16,
        prompt=this_prompt,
        temperature=temp,
        top_p=top_p,
        do_sample=True,
        return_as_list=True,
        use_cache=True,
    )
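    # NOTE: with return_as_list=True, ai.generate() returns a list of strings,
    # each containing the prompt followed by the model's continuation, so the
    # reply has to be parsed back out of the generated text below.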
    if verbose:
        pp.pprint(this_result)  # to see what is going on
    try:
        # split the generated text into lines and clean both the output and the
        # original prompt so the two can be compared line by line
        this_result = str(this_result[0]).split("\n")
        res_out = [clean(ele) for ele in this_result]
        p_out = [clean(ele) for ele in p_list]
        if verbose:
            pp.pprint(res_out)  # to see what is going on
            pp.pprint(p_out)  # to see what is going on

        # keep only the lines that belong to the responder's reply
        diff_list = []
        name_counter = 0
        break_safe = False
        for resline in res_out:
            if (responder + ":") in resline:
                name_counter += 1
                break_safe = True  # next line is a response from the bot
                continue
            if ":" in resline and name_counter > 0:
                if break_safe:
                    diff_list.append(resline)
                    break_safe = False
                    continue  # was missing: without it, this line was appended twice
                else:
                    # a second speaker tag means the reply has ended
                    break
            if resline in p_out:
                # skip lines that merely echo the prompt
                break_safe = False
                continue
            else:
                diff_list.append(resline)
                break_safe = False
        if verbose:
            print("------------------------diff list: ")
            pp.pprint(diff_list)  # to see what is going on
            print("---------------------------------")
        output = ", ".join(diff_list)
    except Exception:
        # fall back to a canned reply if parsing the generated text fails
        output = "oops, there was an error. try again"
    p_list.append(output + "\n")
    p_list.append("\n")
    model_responses = {"out_text": output, "full_conv": p_list}
    print("finished!\n")
    return model_responses
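

# A minimal sketch (illustrative, not part of the original script) of calling
# query_gpt_model() directly from Python instead of via the CLI; the folder
# name is the script's default model and the prompt is an example:
#
#   resp = query_gpt_model(
#       folder_path=Path.cwd() / "GPT2_trivNatQAdailydia_774M_175Ksteps",
#       prompt_msg="where is the grocery store?",
#       verbose=True,
#   )
#   print(resp["out_text"])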


# Set up the parsing of command-line arguments
def get_parser():
    """
    get_parser - a helper function for the argparse module

    Returns:
        argparse.ArgumentParser: the argparser relevant for this script
    """
    parser = argparse.ArgumentParser(
        description="submit a message and have a 774M-parameter GPT model respond"
    )
    parser.add_argument(
        "--prompt",
        required=True,  # MUST HAVE A PROMPT
        type=str,
        help="the message the bot is supposed to respond to. Prompt is said by speaker, answered by responder.",
    )
    parser.add_argument(
        "--model",
        required=False,
        type=str,
        # "gp2_DDandPeterTexts_774M_73Ksteps" - from GPT-Peter
        default="GPT2_trivNatQAdailydia_774M_175Ksteps",
        help="folder - with respect to the git directory of your repo - that has the model files "
        "(pytorch_model.bin + config.json). No models? Run the script download_models.py",
    )
    parser.add_argument(
        "--speaker",
        required=False,
        default=None,
        help="who the prompt is from (to the bot). Primarily relevant for bots trained on multi-individual chat data",
    )
    parser.add_argument(
        "--responder",
        required=False,
        default="person beta",
        help="who the responder is. Primarily relevant for bots trained on multi-individual chat data",
    )
    parser.add_argument(
        "--topk",
        required=False,
        type=int,
        default=150,
        help="top-k sampling: sample from the k most likely next tokens (positive integer). higher = more random responses",
    )
    parser.add_argument(
        "--temp",
        required=False,
        type=float,
        default=0.75,
        help="specify temperature hyperparam (0-1). roughly considered as 'model creativity'",
    )
    parser.add_argument(
        "--topp",
        required=False,
        type=float,
        default=0.65,
        help="nucleus sampling fraction (0-1), i.e. what fraction of the probability mass is considered",
    )
    parser.add_argument(
        "--verbose",
        default=False,
        action="store_true",
        help="pass this argument if you want all the printouts",
    )
    parser.add_argument(
        "--time",
        default=False,
        action="store_true",
        help="pass this argument if you want to know the runtime",
    )
    return parser
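

# The parser can also be exercised without the command line by passing an
# argument list explicitly (an illustrative example, not part of the original
# script):
#
#   args = get_parser().parse_args(["--prompt", "where is the grocery store?", "--time"])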
| if __name__ == "__main__": | |
| args = get_parser().parse_args() | |
| query = args.prompt | |
| model_dir = str(args.model) | |
| model_loc = Path.cwd() / model_dir | |
| spkr = args.speaker | |
| rspndr = args.responder | |
| k_results = args.topk | |
| my_temp = args.temp | |
| my_top_p = args.topp | |
| want_verbose = args.verbose | |
| want_rt = args.time | |
| # force-update the speaker+responder params for the generic model case | |
| if "dailydialogue" in model_dir.lower(): | |
| spkr = "john smith" | |
| rspndr = "nancy sellers" | |
| # ^ arbitrary people created when parsing Daily Dialogue dataset | |
| # # force-update the speaker+responder params | |
| # for the generic model case | |
| if "natqa" in model_dir.lower(): | |
| spkr = "person alpha" | |
| rspndr = "person beta" | |
| # ^ arbitrary people created when parsing NatQA + TriviaQA + Daily Dialogue datasets | |
| st = time.time() | |
| resp = query_gpt_model( | |
| folder_path=model_loc, | |
| prompt_msg=query, | |
| speaker=spkr, | |
| responder=rspndr, | |
| kparam=k_results, | |
| temp=my_temp, | |
| top_p=my_top_p, | |
| verbose=want_verbose, | |
| use_gpu=False, | |
| ) | |
| output = resp["out_text"] | |
| pp.pprint(output, indent=4) | |
| # pp.pprint(this_result[3].strip(), indent=4) | |
| rt = round(time.time() - st, 1) | |
| if want_rt: | |
| print("took {runtime} seconds to generate. \n".format(runtime=rt)) | |
| if want_verbose: | |
| print("finished - ", datetime.now()) | |
| if want_verbose: | |
| p_list = resp["full_conv"] | |
| print("A transcript of your chat is as follows: \n") | |
| p_list = [item.strip() for item in p_list] | |
| pp.pprint(p_list) | |