🔊 add logs
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
- constrained_generation.py +7 -1
- converse.py +1 -0
- utils.py +13 -0
constrained_generation.py CHANGED

@@ -4,6 +4,7 @@
 
 import copy
 import logging
+
 logging.basicConfig(level=logging.INFO)
 import time
 from pathlib import Path
@@ -11,6 +12,7 @@ from pathlib import Path
 import yake
 from transformers import AutoTokenizer, PhrasalConstraint
 
+
 def get_tokenizer(model_name="gpt2", verbose=False):
     """
     get_tokenizer - returns a tokenizer object
@@ -164,6 +166,8 @@ def constrained_generation(
     -------
     response : str, generated text
     """
+    logging.debug(f" constraining generation with {locals()}")
+
     st = time.perf_counter()
     tokenizer = tokenizer or copy.deepcopy(pipeline.tokenizer)
     tokenizer.add_prefix_space = True
@@ -228,7 +232,9 @@
         force_words_ids=force_words_ids if force_flexible is not None else None,
         max_length=None,
         max_new_tokens=max_generated_tokens,
-        min_length=min_generated_tokens + prompt_length,
+        min_length=min_generated_tokens + prompt_length
+        if full_text
+        else min_generated_tokens,
         num_beams=num_beams,
         no_repeat_ngram_size=no_repeat_ngram_size,
         num_return_sequences=num_return_sequences,
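The min_length change above ties the minimum length to whether the prompt tokens count toward the output (full_text): when they do, min_generated_tokens alone would be too small, so the prompt's token count is added on top, while max_new_tokens always refers only to the continuation. A minimal sketch of the same conditional, using a hypothetical resolve_min_length helper rather than the repo's own code:

```python
# Minimal sketch (not the repo's code) of the conditional added above.
# resolve_min_length is a hypothetical helper; prompt_length is measured the same
# way converse.py measures its input, via the tokenizer's input_ids.
from transformers import AutoTokenizer


def resolve_min_length(prompt, min_generated_tokens, full_text, tokenizer):
    prompt_length = len(tokenizer(prompt).input_ids)
    # count the prompt only when the minimum applies to the full returned text
    return min_generated_tokens + prompt_length if full_text else min_generated_tokens


tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(resolve_min_length("Tell me about constrained decoding.", 32, True, tokenizer))
```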
converse.py CHANGED

@@ -186,6 +186,7 @@ def gen_response(
     str, the generated text
 
     """
+    logging.debug(f"input args - gen_response() : {locals()}")
     input_len = len(pipeline.tokenizer(query).input_ids)
     if max_length + input_len > 1024:
         max_length = max(1024 - input_len, 8)
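Both files now dump locals() through logging.debug, but constrained_generation.py configures the root logger at INFO, so these messages stay silent by default. A short sketch (standard-library logging only) of how a caller could surface them:

```python
# Sketch: lower the root logger to DEBUG so the new locals() dumps are emitted.
import logging

logging.basicConfig(level=logging.INFO)  # mirrors what constrained_generation.py sets up
logging.getLogger().setLevel(logging.DEBUG)  # opt in to debug output at runtime
logging.debug("debug records, including the locals() dumps, are now visible")
```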
utils.py CHANGED

@@ -7,6 +7,7 @@ from pathlib import Path
 import pprint as pp
 import re
 import shutil  # zipfile formats
+import logging
 from datetime import datetime
 from os.path import basename
 from os.path import getsize, join
@@ -383,3 +384,15 @@ def cleantxt_wrap(ugly_text, all_lower=False):
         return clean(ugly_text, lower=all_lower)
     else:
         return ugly_text
+
+
+def setup_logging(loglevel):
+    """Setup basic logging
+
+    Args:
+        loglevel (int): minimum loglevel for emitting messages
+    """
+    logformat = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s"
+    logging.basicConfig(
+        level=loglevel, stream=sys.stdout, format=logformat, datefmt="%Y-%m-%d %H:%M:%S"
+    )