Commit 8c7c98a · jonathanlehner committed · 1 Parent(s): 2ff1e50

added dialoggpt
Browse files
- .gitignore +36 -0
- Pipfile +21 -0
- README 2.md +38 -0
- ai_single_response.py +278 -0
- app.py +196 -0
- config.json +34 -0
- file_test.py +3 -0
- requirements.txt +101 -0
- utils.py +282 -0
.gitignore
ADDED
@@ -0,0 +1,36 @@
+# python basics
+/__pycache__/
+/.idea/
+/scratch/
+
+
+# local model folders for testing / running bots / deploy
+
+/gpt2_std_gpu_774M_120ksteps/
+/gpt2_std_gpu_774M_60ksteps/
+/gpt2_dailydialogue_355M_75Ksteps/
+/gp2_DDandPeterTexts_14kPeter_774M/
+/gp2_DDandPeterTexts_41kPeter-774M/
+/gp2_DDandPeterTexts_774M_73Ksteps/
+/gp2_DDandPeterTexts_gpu_774M_175Ksteps/
+*checkpoint*
+*GPT2*
+*GPTneo*
+*GPTpeter*
+*1pt3B*
+
+# most of ^ can be downloaded through `download_models.py`
+
+# gradio things
+*.db
+*.db-journal
+*gradio_queue*
+gradio_data
+deploy-as-bot/flagged
+deploy-as-bot/gradio_data
+deploy-as-bot/gradio_queue.db
+
+
+# notebooks containing personal data
+.DS_Store
+aitextgen
Pipfile
ADDED
@@ -0,0 +1,21 @@
+[[source]]
+url = "https://pypi.org/simple"
+verify_ssl = true
+name = "pypi"
+
+[packages]
+natsort = "==7.1.1"
+pandas = "==1.3.0"
+symspellpy = "==6.7.0"
+requests = "==2.24.0"
+transformers = "==4.8.2"
+gradio = "==1.7.7"
+tqdm = "==4.43.0"
+aitextgen = "==0.5.2"
+cleantext = "==1.1.3"
+telegram = "==0.0.1"
+
+[dev-packages]
+
+[requires]
+python_version = "3.8"
README 2.md
ADDED
@@ -0,0 +1,38 @@
+---
+title: Ai Msgbot Gpt2 M XL
+emoji: 📉
+colorFrom: yellow
+colorTo: purple
+sdk: gradio
+app_file: app.py
+pinned: false
+---
+
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio` or `streamlit`
+
+`sdk_version`: _string_
+Only applicable for `streamlit` SDK.
+See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+Path is relative to the root of the repository.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
+
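Per the key descriptions above, the two keys that actually determine how the Space runs are `sdk` and `app_file`; the rest control how the Space card is displayed. A sketch of a stripped-down front matter, assuming the display keys can be left to their defaults:

```yaml
---
sdk: gradio        # which SDK runs the Space
app_file: app.py   # entry point, relative to the repo root
---
```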
ai_single_response.py
ADDED
@@ -0,0 +1,278 @@
+"""
+ai_single_response.py
+
+An executable way to call the model. example:
+*\gpt2_chatbot> python .\ai_single_response.py --prompt "where is the grocery store?" --time
+
+extended-summary:
+
+Passes a single prompt message to the model and returns the generated reply: the speaker poses the prompt, the responder (the bot) answers, and the answer is returned together with the running conversation.
+
+"""
+import argparse
+import pprint as pp
+import time
+import warnings
+from datetime import datetime
+from pathlib import Path
+from cleantext import clean
+
+warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
+
+from aitextgen import aitextgen
+
+
+def query_gpt_model(
+    folder_path,
+    prompt_msg: str,
+    speaker=None,
+    responder="person beta",
+    kparam=150,
+    temp=0.75,
+    top_p=0.65,
+    verbose=False,
+    use_gpu=False,
+):
+    """
+    query_gpt_model [pass a prompt in to model, get a response. Does NOT "remember" past conversation]
+
+    Args:
+        folder_path (str or Path): [folder with the model files; in this version it only drives the speaker/responder heuristics below, as the model is loaded from the Hugging Face Hub]
+        prompt_msg (str): [the message the model should respond to]
+        speaker (str, optional): [who said the prompt]. Defaults to None.
+        responder (str, optional): [who answers the prompt]. Defaults to "person beta".
+        kparam (int, optional): [top-k sampling parameter]. Defaults to 150.
+        temp (float, optional): [sampling temperature]. Defaults to 0.75.
+        top_p (float, optional): [nucleus sampling fraction]. Defaults to 0.65.
+        verbose (bool, optional): [print extra runtime details]. Defaults to False.
+        use_gpu (bool, optional): [run generation on a GPU]. Defaults to False.
+
+    Returns:
+        [dict]: [returns a dict with A) just model response as str B) total conversation]
+    """
+    ai = aitextgen(
+        model="microsoft/DialoGPT-medium",
+        # model_folder=folder_path,
+        to_gpu=False,
+    )
+    print("loaded model")
+    p_list = []
+    if "natqa" in str(folder_path).lower():
+        speaker = "person alpha"  # manual correction
+        responder = "person beta"
+    if "wow" in str(folder_path).lower():
+        speaker = "person alpha"  # manual correction
+        responder = "person beta"
+    if "peter" in str(folder_path).lower():
+        speaker = None  # manual correction
+        responder = "peter szemraj"
+    if speaker is not None:
+        p_list.append(speaker.lower() + ":" + "\n")  # write prompt as the speaker
+    p_list.append(prompt_msg.lower() + "\n")
+    p_list.append("\n")
+    p_list.append(responder.lower() + ":" + "\n")
+    this_prompt = "".join(p_list)
+    if verbose:
+        print("overall prompt:\n")
+        pp.pprint(this_prompt, indent=4)
+    print("\n... generating... \n")
+    this_result = ai.generate(
+        n=1,
+        top_k=kparam,
+        batch_size=512,
+        max_length=128,
+        min_length=16,
+        prompt=this_prompt,
+        temperature=temp,
+        top_p=top_p,
+        do_sample=True,
+        return_as_list=True,
+        use_cache=True,
+    )
+    if verbose:
+        pp.pprint(this_result)  # to see what is going on
+    try:
+        this_result = str(this_result[0]).split("\n")
+        res_out = [clean(ele) for ele in this_result]
+        p_out = [clean(ele) for ele in p_list]
+        if verbose:
+            pp.pprint(res_out)  # to see what is going on
+            pp.pprint(p_out)  # to see what is going on
+
+        diff_list = []
+        name_counter = 0
+        break_safe = False
+        for resline in res_out:
+
+            if (responder + ":") in resline:
+                name_counter += 1
+                break_safe = True  # next line a response from bot
+                continue
+            if ":" in resline and name_counter > 0:
+                if break_safe:
+                    diff_list.append(resline)
+                    break_safe = False
+                else:
+                    break
+            if resline in p_out:
+                break_safe = False
+                continue
+
+            else:
+                diff_list.append(resline)
+                break_safe = False
+
+        if verbose:
+            print("------------------------diff list: ")
+            pp.pprint(diff_list)  # to see what is going on
+            print("---------------------------------")
+
+        output = ", ".join(diff_list)
+
+    except:
+        output = "oops, there was an error. try again"
+
+    p_list.append(output + "\n")
+    p_list.append("\n")
+
+    model_responses = {"out_text": output, "full_conv": p_list}
+    print("finished!\n")
+
+    return model_responses
+
+
+# Set up the parsing of command-line arguments
+def get_parser():
+    """
+    get_parser [a helper function for the argparse module]
+
+    Returns:
+        [argparse.ArgumentParser]: [the argparser relevant for this script]
+    """
+
+    parser = argparse.ArgumentParser(
+        description="submit a message and have a 774M parameter GPT model respond"
+    )
+    parser.add_argument(
+        "--prompt",
+        required=True,  # MUST HAVE A PROMPT
+        type=str,
+        help="the message the bot is supposed to respond to. Prompt is said by speaker, answered by responder.",
+    )
+    parser.add_argument(
+        "--model",
+        required=False,
+        type=str,
+        # "gp2_DDandPeterTexts_774M_73Ksteps", - from GPT-Peter
+        default="GPT2_trivNatQAdailydia_774M_175Ksteps",
+        help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
+        "config.json). No models? Run the script download_models.py",
+    )
+
+    parser.add_argument(
+        "--speaker",
+        required=False,
+        default=None,
+        help="Who the prompt is from (to the bot). Primarily relevant to bots trained on multi-individual chat data",
+    )
+    parser.add_argument(
+        "--responder",
+        required=False,
+        default="person beta",
+        help="who the responder is. Primarily relevant to bots trained on multi-individual chat data",
+    )
+
+    parser.add_argument(
+        "--topk",
+        required=False,
+        type=int,
+        default=150,
+        help="top-k sampling: how many candidate tokens are considered at each step (positive integer)",
+    )
+
+    parser.add_argument(
+        "--temp",
+        required=False,
+        type=float,
+        default=0.75,
+        help="specify temperature hyperparam (0-1). roughly considered as 'model creativity'",
+    )
+
+    parser.add_argument(
+        "--topp",
+        required=False,
+        type=float,
+        default=0.65,
+        help="nucleus sampling frac (0-1). aka: what fraction of possible options are considered?",
+    )
+
+    parser.add_argument(
+        "--verbose",
+        default=False,
+        action="store_true",
+        help="pass this argument if you want all the printouts",
+    )
+    parser.add_argument(
+        "--time",
+        default=False,
+        action="store_true",
+        help="pass this argument if you want to know runtime",
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    args = get_parser().parse_args()
+    query = args.prompt
+    model_dir = str(args.model)
+    model_loc = Path.cwd() / model_dir
+    spkr = args.speaker
+    rspndr = args.responder
+    k_results = args.topk
+    my_temp = args.temp
+    my_top_p = args.topp
+    want_verbose = args.verbose
+    want_rt = args.time
+
+    # force-update the speaker+responder params for the generic model case
+    if "dailydialogue" in model_dir.lower():
+        spkr = "john smith"
+        rspndr = "nancy sellers"
+        # ^ arbitrary people created when parsing the Daily Dialogue dataset
+    # force-update the speaker+responder params
+    # for the generic model case
+    if "natqa" in model_dir.lower():
+        spkr = "person alpha"
+        rspndr = "person beta"
+        # ^ arbitrary people created when parsing NatQA + TriviaQA + Daily Dialogue datasets
+
+    st = time.time()
+
+    resp = query_gpt_model(
+        folder_path=model_loc,
+        prompt_msg=query,
+        speaker=spkr,
+        responder=rspndr,
+        kparam=k_results,
+        temp=my_temp,
+        top_p=my_top_p,
+        verbose=want_verbose,
+        use_gpu=False,
+    )
+
+    output = resp["out_text"]
+    pp.pprint(output, indent=4)
+
+    # pp.pprint(this_result[3].strip(), indent=4)
+    rt = round(time.time() - st, 1)
+
+    if want_rt:
+        print("took {runtime} seconds to generate. \n".format(runtime=rt))
+
+    if want_verbose:
+        print("finished - ", datetime.now())
+    if want_verbose:
+        p_list = resp["full_conv"]
+        print("A transcript of your chat is as follows: \n")
+        p_list = [item.strip() for item in p_list]
+        pp.pprint(p_list)
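Since this version of `query_gpt_model` hard-codes `microsoft/DialoGPT-medium` and uses `folder_path` only for the speaker/responder heuristics, it can be exercised directly from Python as well as from the CLI. A minimal sketch (the `folder_path` value here is a hypothetical folder name that merely triggers the "natqa" heuristic):

```python
from ai_single_response import query_gpt_model

# folder name only drives the speaker/responder heuristics in this version
resp = query_gpt_model(
    folder_path="GPT2_trivNatQAdailydia_774M_175Ksteps",  # hypothetical placeholder
    prompt_msg="where is the grocery store?",
    kparam=150,
    temp=0.75,
    top_p=0.65,
)
print(resp["out_text"])   # just the model's reply
print(resp["full_conv"])  # the prompt + reply transcript
```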
app.py
ADDED
@@ -0,0 +1,196 @@
+"""
+
+deploy-as-bot\gradio_chatbot.py
+
+A script for deploying the chatbot to Gradio. Gradio is a basic "deploy" interface which allows other users to test your model from a web URL. It also enables some basic functionality like user flagging of weird responses.
+Note that the URL is displayed once the script is run.
+
+Set the working directory to */deploy-as-bot in terminal before running.
+
+"""
+import os
+import sys
+from os.path import dirname
+
+sys.path.append(dirname(dirname(os.path.abspath(__file__))))
+
+import gradio as gr
+import logging
+import argparse
+import time
+import warnings
+from pathlib import Path
+from cleantext import clean
+from transformers import pipeline
+from datetime import datetime
+from ai_single_response import query_gpt_model
+# from gradio.networking import get_state, set_state
+from flask import Flask, request, session, jsonify, abort, send_file, render_template, redirect
+
+import nltk
+nltk.download('stopwords')
+
+warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
+
+logging.basicConfig()
+cwd = Path.cwd()
+my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects
+
+
+def gramformer_correct(corrector, qphrase: str):
+    """
+    gramformer_correct - correct a string using a text2textgen pipeline model from transformers
+
+    Args:
+        corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model]
+        qphrase (str): [text to be corrected]
+
+    Returns:
+        [str]: [corrected text]
+    """
+
+    try:
+        corrected = corrector(
+            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
+        )
+        return corrected[0]["generated_text"]
+    except:
+        print("NOTE - failed to correct with gramformer")
+        return clean(qphrase)
+
+
+def ask_gpt(message: str, sender: str = ""):
+    """
+    ask_gpt - queries the relevant model with a prompt message and (optional) speaker name
+
+    Args:
+        message (str): prompt message to respond to
+        sender (str, optional): speaker aka who said the message. Defaults to "".
+
+    Returns:
+        [str]: [model response as a string]
+    """
+    st = time.time()
+    prompt = clean(message)  # clean user input
+    prompt = prompt.strip()  # get rid of any extra whitespace
+    if len(prompt) > 200:
+        prompt = prompt[-200:]  # truncate
+    sender = clean(sender.strip())
+    if len(sender) > 2:
+        try:
+            prompt_speaker = clean(sender)
+        except:
+            # there was some issue getting that info, whatever
+            prompt_speaker = None
+    else:
+        prompt_speaker = None
+
+    resp = query_gpt_model(
+        folder_path=model_loc,
+        prompt_msg=prompt,
+        speaker=prompt_speaker,
+        kparam=150,
+        temp=0.75,
+        top_p=0.65,  # optimize this with hyperparam search
+    )
+    bot_resp = gramformer_correct(corrector, qphrase=resp["out_text"])
+    rt = round(time.time() - st, 2)
+    print(f"took {rt} sec to respond")
+
+    return bot_resp
+
+
+def chat(first_and_last_name, message):
+    """
+    chat - helper function that makes the whole gradio thing work.
+
+    Args:
+        first_and_last_name (str or None): [speaker of the prompt, if provided]
+        message (str): [the prompt message to respond to]
+
+    Returns:
+        [str]: [the bot's response, to display]
+    """
+    history = session.get("my_state") or []  # NOTE: flask.session requires an active request context
+    response = ask_gpt(message, sender=first_and_last_name)
+    history.append((f"{first_and_last_name}: " + message, " GPT-Model: " + response))  # + " [end] "))
+    session["my_state"] = history
+    session.modified = True
+    # html = "<div class='chatbot'>"
+    # for user_msg, resp_msg in history:
+    #     html += f"<div class='user_msg'>{user_msg}</div>"
+    #     html += f"<div class='resp_msg' style='color: black'>{resp_msg}</div>"
+    # html += "</div>"
+    return response
+
+
+def get_parser():
+    """
+    get_parser - a helper function for the argparse module
+
+    Returns:
+        [argparse.ArgumentParser]: [the argparser relevant for this script]
+    """
+
+    parser = argparse.ArgumentParser(
+        description="submit a message and have a 774M parameter GPT model respond"
+    )
+    parser.add_argument(
+        "--model",
+        required=False,
+        type=str,
+        # "gp2_DDandPeterTexts_774M_73Ksteps", - from GPT-Peter
+        default="GPT2_trivNatQAdailydia_774M_175Ksteps",
+        help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
+        "config.json). No models? Run the script download_models.py",
+    )
+
+    parser.add_argument(
+        "--gram-model",
+        required=False,
+        type=str,
+        default="prithivida/grammar_error_correcter_v1",
+        help="text2text generation model ID from huggingface for the model to correct grammar",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    args = get_parser().parse_args()
+    default_model = str(args.model)
+    model_loc = cwd.parent / default_model
+    model_loc = str(model_loc.resolve())
+    gram_model = args.gram_model
+    print(f"using model stored here: \n {model_loc} \n")
+    corrector = pipeline("text2text-generation", model=gram_model, device=-1)
+    print("Finished loading the gramformer model - ", datetime.now())
+    iface = gr.Interface(
+        chat,
+        inputs=["text", "text"],
+        outputs="html",
+        title="Real-Impact English Chat Demo 英语聊天演示",
+        description="A basic interface with a neural network model trained on general Q&A and conversation. Treat it like a friend! 带有模型的基本界面,进行了一般问答和对话训练。 请像朋友一样与他对话! \n first and last name 姓名 \n message 信息 \n Clear 清除 \n Submit 确认 \n Screenshot 截屏",
+        article="**Important Notes & About: 重要说明 & 关于我们**\n"
+        "1. the model can take up to 200 seconds to respond sometimes, patience is a virtue. 该模型有时可能需要长达 60 秒的响应时间,请耐心等待。\n"
+        "2. entering a username is completely optional. 姓名输入是可选的。\n "
+        "3. the model was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement. 该模型在几个不同的数据集上训练而成,它所说的任何内容都应该经过事实核查,然后才能被视为真实陈述。\n ",
+        css="""
+        .chatbox {display:flex;flex-direction:column}
+        .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
+        .user_msg {background-color:cornflowerblue;color:white;align-self:start}
+        .resp_msg {background-color:lightgray;align-self:self-end}
+        """,
+        allow_screenshot=True,
+        allow_flagging=False,
+        flagging_dir="gradio_data",
+        flagging_options=[
+            "great response",
+            "doesn't make sense",
+            "bad/offensive response",
+        ],
+        enable_queue=True,  # allows for dealing with multiple users simultaneously
+        # theme="darkhuggingface",
+        # server_name="0.0.0.0",
+    )
+    iface.launch(share=True)
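The grammar-correction step is independent of the chatbot and can be sanity-checked on its own. A minimal sketch using the same `prithivida/grammar_error_correcter_v1` checkpoint and pipeline call that app.py defaults to:

```python
from transformers import pipeline

# same text2text-generation checkpoint app.py uses for response cleanup
corrector = pipeline(
    "text2text-generation",
    model="prithivida/grammar_error_correcter_v1",
    device=-1,  # CPU
)

result = corrector("he go to school yesterday", clean_up_tokenization_spaces=True)
print(result[0]["generated_text"])  # the corrected sentence
```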
config.json
ADDED
@@ -0,0 +1,34 @@
+{
+  "_name_or_path": "/content/drive/MyDrive/Programming/AI_peter/gpt2_dailydialogue_gpu_355M",
+  "activation_function": "gelu_new",
+  "architectures": [
+    "GPT2LMHeadModel"
+  ],
+  "attn_pdrop": 0.1,
+  "bos_token_id": 50256,
+  "embd_pdrop": 0.1,
+  "eos_token_id": 50256,
+  "gradient_checkpointing": true,
+  "initializer_range": 0.02,
+  "layer_norm_epsilon": 1e-05,
+  "line_by_line": false,
+  "model_type": "gpt2",
+  "n_ctx": 1024,
+  "n_embd": 1024,
+  "n_head": 16,
+  "n_inner": null,
+  "n_layer": 24,
+  "n_positions": 1024,
+  "n_vocab": 50257,
+  "resid_pdrop": 0.1,
+  "scale_attn_weights": true,
+  "summary_activation": null,
+  "summary_first_dropout": 0.1,
+  "summary_proj_to_labels": true,
+  "summary_type": "cls_index",
+  "summary_use_proj": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.11.3",
+  "use_cache": false,
+  "vocab_size": 50257
+}
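This is a standard `transformers` GPT-2 config for a 355M-class model (24 layers, 16 heads, 1024-dim embeddings). To inspect it programmatically, a small sketch, assuming the file sits in the current working directory:

```python
from transformers import GPT2Config

# load the committed config and check the architecture hyperparameters
cfg = GPT2Config.from_json_file("config.json")
print(cfg.n_layer, cfg.n_head, cfg.n_embd)  # 24 16 1024
```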
file_test.py
ADDED
@@ -0,0 +1,3 @@
+import os
+
+print(os.path.exists("/Users/jonathan/ai-msgbot/gpt2_dailydialogue_355M_150Ksteps/pytorch_model.bin"))
requirements.txt
ADDED
@@ -0,0 +1,101 @@
+absl-py==1.0.0
+aiohttp==3.8.1
+aiosignal==1.2.0
+aitextgen==0.5.2
+analytics-python==1.4.0
+APScheduler==3.6.3
+async-timeout==4.0.2
+attrs==21.2.0
+backoff==1.10.0
+backports.zoneinfo==0.2.1
+bcrypt==3.2.0
+cachetools==4.2.2
+certifi==2021.10.8
+cffi==1.15.0
+chardet==3.0.4
+charset-normalizer==2.0.9
+cleantext==1.1.3
+click==8.0.3
+cryptography==36.0.1
+cycler==0.11.0
+editdistpy==0.1.3
+ffmpy==0.3.0
+filelock==3.4.2
+fire==0.4.0
+Flask==2.0.2
+Flask-CacheBuster==1.0.0
+Flask-Cors==3.0.10
+Flask-Login==0.5.0
+fonttools==4.28.5
+frozenlist==1.2.0
+fsspec==2021.11.1
+future==0.18.2
+google-auth==2.3.3
+google-auth-oauthlib==0.4.6
+gradio==2.4.6
+grpcio==1.43.0
+huggingface-hub==0.2.1
+idna==2.10
+importlib-metadata==4.10.0
+itsdangerous==2.0.1
+Jinja2==3.0.3
+joblib==1.1.0
+kiwisolver==1.3.2
+Markdown==3.3.6
+markdown2==2.4.2
+MarkupSafe==2.0.1
+matplotlib==3.5.1
+monotonic==1.6
+multidict==5.2.0
+natsort==7.1.1
+nltk==3.6.6
+numpy==1.21.5
+oauthlib==3.1.1
+openwa==1.3.16
+packaging==21.3
+pandas==1.3.5
+paramiko==2.9.1
+Pillow==8.4.0
+protobuf==3.19.1
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycparser==2.21
+pycryptodome==3.12.0
+pyDeprecate==0.3.1
+pydub==0.25.1
+PyNaCl==1.4.0
+pyparsing==3.0.6
+python-axolotl==0.2.3
+python-axolotl-curve25519==0.4.1.post2
+python-dateutil==2.8.2
+python-telegram-bot==13.8.1
+pytorch-lightning==1.5.7
+pytz==2021.3
+pytz-deprecation-shim==0.1.0.post0
+PyYAML==6.0
+regex==2021.11.10
+requests==2.24.0
+requests-oauthlib==1.3.0
+rsa==4.8
+sacremoses==0.0.46
+selenium==3.141.0
+six==1.16.0
+symspellpy==6.7.6
+tensorboard==2.7.0
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.0
+termcolor==1.1.0
+tokenizers==0.10.3
+torch==1.10.1
+torchmetrics==0.6.2
+tornado==6.1
+tqdm==4.43.0
+transformers==4.12.5
+typing_extensions==4.0.1
+tzdata==2021.5
+tzlocal==4.1
+urllib3==1.25.11
+webwhatsapi==2.0.5
+Werkzeug==2.0.2
+yarl==1.7.2
+zipp==3.6.0
utils.py
ADDED
@@ -0,0 +1,282 @@
+"""
+general utility functions for loading, saving, etc
+"""
+import os
+from pathlib import Path
+import pprint as pp
+import re
+import shutil  # zipfile formats
+from datetime import datetime
+from os.path import basename
+from os.path import getsize, join
+
+import requests
+from cleantext import clean
+from natsort import natsorted
+from symspellpy import SymSpell
+import pandas as pd
+from tqdm.auto import tqdm
+
+
+def get_timestamp():
+    return datetime.now().strftime("%b-%d-%Y_t-%H")
+
+
+def correct_phrase_load(my_string: str):
+    """
+    correct_phrase_load [basic / unoptimized implementation of SymSpell to correct a string]
+
+    Args:
+        my_string (str): [text to be corrected]
+
+    Returns:
+        [str]: [the corrected text]
+    """
+    sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
+
+    dictionary_path = (
+        r"symspell_rsc/frequency_dictionary_en_82_765.txt"  # from repo root
+    )
+    bigram_path = (
+        r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"  # from repo root
+    )
+    # term_index is the column of the term and count_index is the
+    # column of the term frequency
+    sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)
+    sym_spell.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
+
+    # max edit distance per lookup (per single word, not per whole input string)
+    suggestions = sym_spell.lookup_compound(
+        clean(my_string), max_edit_distance=2, ignore_non_words=True
+    )
+    if len(suggestions) < 1:
+        return my_string
+    else:
+        first_result = suggestions[0]
+        return first_result._term
+
+
+def fast_scandir(dirname: str):
+    """
+    fast_scandir [an os.path-based means to return all subfolders in a given filepath]
+
+    Args:
+        dirname (str): [the directory to scan]
+
+    Returns:
+        [list]: [paths of all subfolders, recursively]
+    """
+
+    subfolders = [f.path for f in os.scandir(dirname) if f.is_dir()]
+    for dirname in list(subfolders):
+        subfolders.extend(fast_scandir(dirname))
+    return subfolders  # list
+
+
+def create_folder(directory: str):
+
+    os.makedirs(directory, exist_ok=True)
+
+
+def chunks(lst: list, n: int):
+    """
+    chunks - Yield successive n-sized chunks from lst
+    Args:
+        lst (list): [the list to split]
+        n (int): [chunk size]
+
+    Yields:
+        [list]: [successive n-sized slices of lst]
+    """
+
+    for i in range(0, len(lst), n):
+        yield lst[i : i + n]
+
+
+def chunky_pandas(my_df, num_chunks: int = 4):
+    """
+    chunky_pandas [split dataframe into `num_chunks` equal chunks, return each inside a list]
+
+    Args:
+        my_df (pd.DataFrame): [the dataframe to split]
+        num_chunks (int, optional): [number of chunks]. Defaults to 4.
+
+    Returns:
+        [list]: [a list of the resulting dataframes]
+    """
+    n = int(len(my_df) // num_chunks)
+    list_df = [my_df[i : i + n] for i in range(0, my_df.shape[0], n)]
+
+    return list_df
+
+
+def load_dir_files(
+    directory: str, req_extension=".txt", return_type="list", verbose=False
+):
+    """
+    load_dir_files - an os.path based method of returning all files with extension `req_extension` in a given directory and subdirectories
+
+    Args:
+        directory (str): [the root directory to search]
+        req_extension (str, optional): [file extension to match]. Defaults to ".txt".
+        return_type (str, optional): ["list" or "dict"]. Defaults to "list".
+        verbose (bool, optional): [print the files found]. Defaults to False.
+
+    Returns:
+        [list or dict]: [the matching file paths]
+    """
+    appr_files = []
+    # r=root, d=directories, f = files
+    for r, d, f in os.walk(directory):
+        for prefile in f:
+            if prefile.endswith(req_extension):
+                fullpath = os.path.join(r, prefile)
+                appr_files.append(fullpath)
+
+    appr_files = natsorted(appr_files)
+
+    if verbose:
+        print("A list of files in the {} directory are: \n".format(directory))
+        if len(appr_files) < 10:
+            pp.pprint(appr_files)
+        else:
+            pp.pprint(appr_files[:10])
+            print("\n and more. There are a total of {} files".format(len(appr_files)))
+
+    if return_type.lower() == "list":
+        return appr_files
+    else:
+        if verbose:
+            print("returning dictionary")
+
+        appr_file_dict = {}
+        for this_file in appr_files:
+            appr_file_dict[basename(this_file)] = this_file
+
+        return appr_file_dict
+
+
+def URL_string_filter(text):
+    """
+    URL_string_filter - filter out nonstandard "text" characters
+
+    Args:
+        text (str): [the text to filter]
+
+    Returns:
+        [str]: [the filtered text]
+    """
+    custom_printable = (
+        "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ._"
+    )
+
+    filtered = "".join((filter(lambda i: i in custom_printable, text)))
+
+    return filtered
+
+
+def getFilename_fromCd(cd):
+    if not cd:
+        return None
+    fname = re.findall("filename=(.+)", cd)
+    if len(fname) > 0:
+        output = fname[0]
+    elif cd.find("/"):
+        possible_fname = cd.rsplit("/", 1)[1]
+        output = URL_string_filter(possible_fname)
+    else:
+        output = None
+    return output
+
+
+def get_zip_URL(
+    URLtoget: str,
+    extract_loc: str = None,
+    file_header: str = "dropboxexport_",
+    verbose: bool = False,
+):
+    """
+    get_zip_URL [download a zip file from a URL and extract it locally]
+
+    Args:
+        URLtoget (str): [the URL of the zip file]
+        extract_loc (str, optional): [folder to extract into]. Defaults to None.
+        file_header (str, optional): [prefix for the saved filename]. Defaults to "dropboxexport_".
+        verbose (bool, optional): [print extra details]. Defaults to False.
+
+    Returns:
+        [str]: [the path the archive was extracted to]
+    """
+    r = requests.get(URLtoget, allow_redirects=True)
+    names = getFilename_fromCd(r.headers.get("content-disposition"))
+    fixed_fnames = names.split(";")  # split the multiple results
+    this_filename = file_header + URL_string_filter(fixed_fnames[0])
+
+    # define paths and save the zip file
+    if extract_loc is None:
+        extract_loc = "dropbox_dl"
+    dl_place = join(os.getcwd(), extract_loc)
+    create_folder(dl_place)
+    save_loc = join(os.getcwd(), this_filename)
+    open(save_loc, "wb").write(r.content)
+    if verbose:
+        print("downloaded file size was {} MB".format(getsize(save_loc) / 1000000))
+
+    # unpack the archive
+    shutil.unpack_archive(save_loc, extract_dir=dl_place)
+    if verbose:
+        print("extracted zip file - ", datetime.now())
+        x = load_dir_files(dl_place, req_extension="", verbose=verbose)
+
+    # remove original
+    try:
+        os.remove(save_loc)
+        del save_loc
+    except:
+        print("unable to delete original zipfile - check if exists", datetime.now())
+
+    print("finished extracting zip - ", datetime.now())
+
+    return dl_place
+
+
+def merge_dataframes(data_dir: str, ext=".xlsx", verbose=False):
+    """
+    merge_dataframes - given a filepath, loads and attempts to merge all files as dataframes
+
+    Args:
+        data_dir (str): [root directory to search in]
+        ext (str, optional): [anticipated file extension for the dataframes]. Defaults to '.xlsx'.
+
+    Returns:
+        pd.DataFrame(): merged dataframe
+    """
+
+    src = Path(data_dir)
+    src_str = str(src.resolve())
+    mrg_df = pd.DataFrame()
+
+    all_reports = load_dir_files(directory=src_str, req_extension=ext, verbose=verbose)
+
+    failed = []
+
+    for df_path in tqdm(all_reports, total=len(all_reports), desc="joining data..."):
+
+        try:
+            this_df = pd.read_excel(df_path).convert_dtypes()
+
+            mrg_df = pd.concat([mrg_df, this_df], axis=0)
+        except:
+            short_p = os.path.basename(df_path)
+            print(
+                f"WARNING - file with extension {ext} and name {short_p} could not be read."
+            )
+            failed.append(short_p)
+
+    if len(failed) > 0:
+        print("failed to merge {} files, investigate as needed".format(len(failed)))
+
+    if verbose:
+        pp.pprint(mrg_df.info(True))
+
+    return mrg_df
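Most of these helpers are generic, and `load_dir_files` composes naturally with `chunks` for batched processing. A short sketch (the "data" directory path is a hypothetical example):

```python
from utils import chunks, load_dir_files

# gather all .txt files under a (hypothetical) data directory...
files = load_dir_files("data", req_extension=".txt", verbose=True)

# ...then process them in batches of 8
for batch in chunks(files, 8):
    print(f"processing {len(batch)} files")
```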