[#1] main_infer.py implemented
Browse files- idiomify/fetchers.py +1 -2
- idiomify/idiomifier.py +22 -0
- idiomify/models.py +0 -8
- main_infer.py +28 -37
idiomify/fetchers.py
CHANGED
|
@@ -95,8 +95,7 @@ def fetch_alpha(ver: str, run: Run = None) -> Alpha:
|
|
| 95 |
artifact_dir = artifact.download(root=alpha_dir(ver))
|
| 96 |
ckpt_path = path.join(artifact_dir, "model.ckpt")
|
| 97 |
bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
|
| 98 |
-
|
| 99 |
-
alpha = Alpha.load_from_checkpoint(ckpt_path, bart=bart)
|
| 100 |
return alpha
|
| 101 |
|
| 102 |
|
|
|
|
| 95 |
artifact_dir = artifact.download(root=alpha_dir(ver))
|
| 96 |
ckpt_path = path.join(artifact_dir, "model.ckpt")
|
| 97 |
bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
|
| 98 |
+
alpha = Alpha.load_from_checkpoint(ckpt_path, bart=bart)
|
|
|
|
| 99 |
return alpha
|
| 100 |
|
| 101 |
|
idiomify/idiomifier.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import BartTokenizer
from builders import SourcesBuilder
from models import Alpha


class Idiomifier:
    """Callable wrapper that paraphrases a literal sentence into its
    idiomatic form with a fine-tuned BART model."""

    def __init__(self, model: Alpha, tokenizer: BartTokenizer):
        """Keep the model (switched to eval mode) and wrap the tokenizer
        in a source-sequence builder."""
        self.model = model
        self.builder = SourcesBuilder(tokenizer)
        self.model.eval()

    def __call__(self, src: str, max_length=100) -> str:
        """Idiomify a single sentence and return the decoded prediction."""
        # Encode the lone source sentence; the target side is left blank.
        encoded = self.builder(literal2idiomatic=[(src, "")])
        input_ids, masks = encoded[:, 0], encoded[:, 1]  # (N, 2, L) -> (N, L) each
        # Generate, seeding the decoder with the configured BOS token id.
        generated = self.model.bart.generate(
            inputs=input_ids,
            attention_mask=masks,
            decoder_start_token_id=self.model.hparams['bos_token_id'],
            max_length=max_length,
        )
        pred_ids = generated.squeeze()  # (N, L_t) -> (L_t); N == 1 for a single src
        return self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
|
idiomify/models.py
CHANGED
|
@@ -47,14 +47,6 @@ class Alpha(pl.LightningModule): # noqa
|
|
| 47 |
def on_train_batch_end(self, outputs: dict, *args, **kwargs):
|
| 48 |
self.log("Train/Loss", outputs['loss'])
|
| 49 |
|
| 50 |
-
def predict(self, srcs: torch.Tensor) -> torch.Tensor:
|
| 51 |
-
pred_ids = self.bart.generate(
|
| 52 |
-
inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
|
| 53 |
-
attention_mask=srcs[:, 1], # (N, 2, L) -> (N, L)
|
| 54 |
-
decoder_start_token_id=self.hparams['bos_token_id'],
|
| 55 |
-
)
|
| 56 |
-
return pred_ids # (N, L)
|
| 57 |
-
|
| 58 |
def configure_optimizers(self) -> torch.optim.Optimizer:
|
| 59 |
"""
|
| 60 |
Instantiates and returns the optimizer to be used for this model
|
|
|
|
| 47 |
def on_train_batch_end(self, outputs: dict, *args, **kwargs):
|
| 48 |
self.log("Train/Loss", outputs['loss'])
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
def configure_optimizers(self) -> torch.optim.Optimizer:
|
| 51 |
"""
|
| 52 |
Instantiates and returns the optimizer to be used for this model
|
main_infer.py
CHANGED
|
@@ -1,37 +1,28 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
# # sort and append
|
| 30 |
-
# res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
|
| 31 |
-
# print(f"query: {colored(text=config['sent'], color='blue')}")
|
| 32 |
-
# for idx, (idiom, prob) in enumerate(res):
|
| 33 |
-
# print(idx, idiom, prob)
|
| 34 |
-
#
|
| 35 |
-
#
|
| 36 |
-
# if __name__ == '__main__':
|
| 37 |
-
# main()
|
|
|
|
| 1 |
+
import argparse
from termcolor import colored
from idiomifier import Idiomifier
from idiomify.fetchers import fetch_config, fetch_alpha
from transformers import BartTokenizer


def main():
    """CLI entry point: fetch a trained checkpoint and idiomify --src."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model", type=str,
                            default="alpha")
    arg_parser.add_argument("--ver", type=str,
                            default="overfit")
    arg_parser.add_argument("--src", type=str,
                            default="If there's any benefits to losing my job, it's that I'll now be able to go to school full-time and finish my degree earlier.")
    parsed = arg_parser.parse_args()
    # CLI flags override the per-model / per-version config entries.
    config = fetch_config()[parsed.model][parsed.ver]
    config.update(vars(parsed))
    # Assemble the inference pipeline from the fetched checkpoint.
    alpha = fetch_alpha(config['ver'])
    tokenizer = BartTokenizer.from_pretrained(config['bart'])
    idiomify = Idiomifier(alpha, tokenizer)
    src = config['src']
    tgt = idiomify(src=config['src'])
    print(src, "\n->", colored(tgt, "blue"))


if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|