Spaces:
Runtime error
Runtime error
| #mRASP2 tester. | |
| import ast | |
| import fileinput | |
| import logging | |
| import math | |
| import os | |
| import sys | |
| import time | |
| from argparse import Namespace | |
| from collections import namedtuple | |
| import pickle | |
| from omegaconf import DictConfig, OmegaConf | |
| import numpy as np | |
| import torch | |
# Environment bootstrap, executed at import time. It:
#   - copies the mRASP2 sources into the working directory,
#   - installs their requirements,
#   - builds fairseq from a fresh checkout,
#   - downloads the BPE vocabulary and BPE codes files.
# It has to stay above the fairseq imports further down, which presumably only
# resolve once the fairseq install step has succeeded.
#
# os.system() signals failure only through its return value. Discarding it
# means a failed clone/install/download can surface much later as an
# unrelated-looking error, so failures are reported on stderr here.
# They are deliberately NOT fatal: on a re-run, e.g. the fairseq `git clone`
# fails simply because the checkout already exists.
model_name = "12e12d_last.pt"
for _setup_cmd in (
    "git clone https://github.com/PANXiao1994/mRASP2.git",
    "mv -n mRASP2/* ./",
    "rm -rf mRASP2",
    "pip install -r requirements.txt",
    "git clone https://github.com/pytorch/fairseq",
    # `&&` rather than `;` so pip never installs from the wrong directory when
    # the checkout above is missing. Each os.system() call runs in its own
    # shell, so the former trailing `cd ..` had no effect and is dropped.
    "cd fairseq && pip install ./",
    "wget https://lf3-nlp-opensource.bytetos.com/obj/nlp-opensource/acl2021/mrasp2/bpe_vocab",
    "wget https://lf3-nlp-opensource.bytetos.com/obj/nlp-opensource/emnlp2020/mrasp/pretrain/dataset/codes.bpe.32000",
):
    _setup_status = os.system(_setup_cmd)
    if _setup_status != 0:
        # Logging is not configured yet at this point in the module, so write
        # straight to stderr instead of touching the root logger.
        print(
            f"setup command failed (status {_setup_status}): {_setup_cmd}",
            file=sys.stderr,
        )
| from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils | |
| from fairseq.dataclass.configs import FairseqConfig | |
| from fairseq.dataclass.utils import convert_namespace_to_omegaconf | |
| from fairseq.token_generation_constraints import pack_constraints, unpack_constraints | |
| from fairseq_cli.generate import get_symbols_to_strip_from_output | |
| from fairseq_cli.interactive import buffered_read, make_batches | |
# Send log records to stdout with a "time | level | logger | message" layout.
# The threshold comes from the LOGLEVEL environment variable (default INFO).
_log_level = os.environ.get("LOGLEVEL", "INFO").upper()
logging.basicConfig(
    stream=sys.stdout,
    level=_log_level,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
# Module logger; it reuses the "fairseq_cli.interactive" logger name.
logger = logging.getLogger("fairseq_cli.interactive")

# Plain record types: one batch of encoded inputs, and one translation result.
Batch = namedtuple("Batch", ["ids", "src_tokens", "src_lengths", "constraints"])
Translation = namedtuple(
    "Translation", ["src_str", "hypos", "pos_scores", "alignments"]
)
def createCFG(path: str = "cfg.txt") -> DictConfig:
    """Load a pickled configuration file and wrap it in an OmegaConf config.

    Args:
        path: File containing the pickled configuration (presumably a dict,
            judging by how it is used). Defaults to ``"cfg.txt"``, resolved
            against the current working directory.

    Returns:
        The ``DictConfig`` produced by ``OmegaConf.create`` from the unpickled
        object.
    """
    # NOTE(review): hard-coded, machine-specific path that looks like a leftover
    # from a local setup. Assigning os.environ only affects child processes
    # started afterwards; it does not change this interpreter's sys.path.
    # Kept so the existing side effect is preserved. Consider removing it.
    os.environ['PYTHONPATH'] = "/home/chinmay/.local/lib/python3.10/site-packages"
    # SECURITY: unpickling can execute arbitrary code. Only load config files
    # shipped with this application, never user-supplied ones.
    with open(path, "rb") as reader:
        cfg_dict = pickle.load(reader)
    return OmegaConf.create(cfg_dict)
def loadmRASP2(cfg):
    """Build everything needed to translate with the configured checkpoint(s).

    Steps run in a fixed order that must be kept:
      1. normalise the config;
      2. seed the RNGs;
      3. set up the task;
      4. load the model ensemble;
      5. convert precision, then move to the GPU;
      6. build the generator, tokenizer, BPE and alignment helpers.

    Args:
        cfg: fairseq configuration. Either an ``argparse.Namespace`` (converted
            to OmegaConf first) or an OmegaConf/dataclass config. It is
            modified in place: ``interactive.buffer_size`` is raised to at
            least 1, and ``dataset.batch_size`` defaults to 1 when neither it
            nor ``dataset.max_tokens`` is set.

    Returns:
        Tuple ``(models, task, max_positions, tokenizer, bpe, use_cuda,
        generator, src_dict, tgt_dict, align_dict, start_time, start_id)``,
        matching the parameters of ``infer`` that follow its ``cfg``
        argument. ``start_id`` is always 0.

    Raises:
        AssertionError: if sampling is enabled with ``nbest != beam``, or if
            the batch size exceeds the buffer size.
    """
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    start_time = time.time()
    utils.import_user_module(cfg.common)

    interactive_cfg = cfg.interactive
    dataset_cfg = cfg.dataset
    gen_cfg = cfg.generation

    # At least one sentence is buffered; fall back to single-sentence batches
    # when no batching limit was configured at all.
    if interactive_cfg.buffer_size < 1:
        interactive_cfg.buffer_size = 1
    if dataset_cfg.max_tokens is None and dataset_cfg.batch_size is None:
        dataset_cfg.batch_size = 1

    assert not (
        gen_cfg.sampling and gen_cfg.nbest != gen_cfg.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        not dataset_cfg.batch_size
        or dataset_cfg.batch_size <= interactive_cfg.buffer_size
    ), "--batch-size cannot be larger than --buffer-size"

    logger.info(cfg)

    # Seed numpy and torch so stochastic decoding is reproducible.
    seed = cfg.common.seed
    if seed is not None and not gen_cfg.no_seed_provided:
        np.random.seed(seed)
        utils.set_torch_seed(seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Task first: the ensemble loader needs it.
    task = tasks.setup_task(cfg.task)

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    ckpt_cfg = cfg.checkpoint
    shard_count = ckpt_cfg.checkpoint_shard_count
    ensemble, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=ckpt_cfg.checkpoint_suffix,
        strict=shard_count == 1,
        num_shards=shard_count,
    )
    print(cfg)

    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Precision conversion happens before device placement, then each model
    # is switched to inference mode.
    for loaded in (m for m in ensemble if m is not None):
        if cfg.common.fp16:
            loaded.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            loaded.cuda()
        loaded.prepare_for_inference_(cfg)

    generator = task.build_generator(ensemble, gen_cfg)
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)
    # None when unknown-word replacement is disabled; empty when no alignment
    # dictionary path was given.
    align_dict = utils.load_align_dict(gen_cfg.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *(m.max_positions() for m in ensemble)
    )

    if gen_cfg.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )
    if interactive_cfg.buffer_size > 1:
        logger.info("Sentence buffer size: %s", interactive_cfg.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")

    return (
        ensemble,
        task,
        max_positions,
        tokenizer,
        bpe,
        use_cuda,
        generator,
        src_dict,
        tgt_dict,
        align_dict,
        start_time,
        0,
    )
def infer(cfg, models, task, max_positions, tokenizer, bpe, use_cuda, generator, src_dict, tgt_dict, align_dict, start_time, start_id, src_lang, tgt_lang):
    """Translate every sentence in the file ``input.<src_lang>``.

    ``models`` through ``start_id`` are the objects returned by ``loadmRASP2``,
    in the same order. For each input sentence this prints the fairseq-style
    ``S-``/``W-``/``C-`` lines (when a source dictionary exists) and
    ``H-``/``D-``/``P-``/``A-`` lines for each of the top ``nbest``
    hypotheses.

    Args:
        cfg: fairseq configuration. Mutated: ``interactive.input`` and
            ``task.input`` are set to ``input.<src_lang>``, and
            ``task.lang_prefix_tok`` to ``LANG_TOK_<TGT_LANG>`` (presumably
            how the task selects the target language — confirm in the task).
        start_id: Offset added to the per-buffer sentence ids in the output.
        src_lang: Source language code; selects the input file name.
        tgt_lang: Target language code; upper-cased into the prefix token.

    Returns:
        The detokenized text of the last hypothesis printed. That text is also
        written, UTF-8 encoded, to the file ``output``. When no hypothesis is
        produced (e.g. an empty input file) this returns ``""`` and leaves
        ``output`` untouched, instead of raising ``UnboundLocalError``.
    """

    def encode_fn(x):
        # Raw text -> tokenizer -> BPE, skipping stages that are not configured.
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        # Inverse of encode_fn: undo BPE first, then detokenize.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    input_path = "input." + str(src_lang)
    cfg.interactive.input = input_path
    cfg.task.input = input_path
    cfg.task.lang_prefix_tok = "LANG_TOK_" + str(tgt_lang).upper()

    # Detokenized text of the most recent hypothesis; None until one exists.
    last_detok_hypo_str = None
    total_translate_time = 0
    for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size):
        results = []
        for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(
                generator, models, sample, constraints=constraints
            )
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time

            list_constraints = [[] for _ in range(bsz)]
            if cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (sample_id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append(
                    (
                        start_id + sample_id,
                        src_tokens_i,
                        hypos,
                        {
                            "constraints": list_constraints[i],
                            "time": translate_time / len(translations),
                        },
                    )
                )

        # Sort output to match input order.
        for id_, result_src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
            src_str = ""
            if src_dict is not None:
                src_str = src_dict.string(result_src_tokens, cfg.common_eval.post_process)
                print("S-{}\t{}".format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print(
                        "C-{}\t{}".format(
                            id_,
                            tgt_dict.string(constraint, cfg.common_eval.post_process),
                        )
                    )

            # Process top predictions.
            for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
                _hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                last_detok_hypo_str = detok_hypo_str
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                print(
                    "P-{}\t{}".format(
                        id_,
                        " ".join(
                            map(
                                lambda x: "{:.4f}".format(x),
                                # convert from base e to base 2 (div_ modifies
                                # the hypothesis tensor in place)
                                hypo["positional_scores"].div_(math.log(2)).tolist(),
                            )
                        ),
                    )
                )
                if cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment]
                    )
                    print("A-{}\t{}".format(id_, alignment_str))

        # Update the running id counter.
        start_id += len(inputs)

    # The file previously was rewritten once per hypothesis, ending with the
    # last one. It is now written once, with that same final content, and
    # with an explicit encoding because translations are routinely non-ASCII.
    if last_detok_hypo_str is not None:
        with open("output", "w", encoding="utf-8") as writer:
            writer.write(last_detok_hypo_str)

    logger.info(
        "Total time: {:.3f} seconds; translation time: {:.3f}".format(
            time.time() - start_time, total_translate_time
        )
    )
    return "" if last_detok_hypo_str is None else last_detok_hypo_str