Spaces:
Runtime error
Runtime error
| import re | |
| import os | |
| import sys | |
| from tqdm import tqdm | |
def remove_bpe(line, bpe_symbol="@@ "):
    """Undo BPE segmentation on a single line of text.

    Every newline character is dropped and every occurrence of
    ``bpe_symbol`` is deleted. A space is appended before the deletion, so a
    marker at the very end of the line (``"@@"`` with the default symbol)
    is removed too. Trailing whitespace is then stripped and exactly one
    newline is appended.

    Args:
        line: Input text, possibly containing newline characters.
        bpe_symbol: Continuation marker to delete (default ``"@@ "``).

    Returns:
        The de-segmented text terminated by a single newline.
    """
    joined = "".join(line.split("\n")) + " "
    merged = joined.replace(bpe_symbol, "")
    return merged.rstrip() + "\n"
def remove_bpe_fn(i=sys.stdin, o=sys.stdout, bpe="@@ "):
    """Stream lines from ``i`` to ``o`` with BPE markers removed.

    The input iterable is wrapped in ``tqdm`` for progress reporting. Each
    line is passed through ``remove_bpe`` with marker ``bpe`` and written
    immediately via ``o.write``.

    Note: the default ``i``/``o`` are the ``sys.stdin``/``sys.stdout``
    objects that existed when this function was defined.

    Args:
        i: Iterable of text lines (e.g. an open text file).
        o: Object with a ``write(str)`` method receiving the cleaned lines.
        bpe: BPE continuation marker to strip.
    """
    for raw_line in tqdm(i):
        o.write(remove_bpe(raw_line, bpe))
def reprocess(fle):
    """Parse a ``generate.py`` translation output file.

    Only lines that *start* with one of the prefixes ``S-<id>``, ``T-<id>``,
    ``H-<id>`` or ``P-<id>`` are interpreted. Every other line (log messages,
    lines with other prefixes, summaries, ...) is ignored, even if it
    happens to contain such a prefix somewhere in its text.

    Args:
        fle: Path to the generate.py output file.

    Returns:
        A 5-tuple ``(source_dict, hypothesis_dict, score_dict, target_dict,
        pos_score_dict)``, each keyed by the integer sentence id:

        * ``source_dict[id]``: source text of the ``S`` line (newline-terminated).
        * ``hypothesis_dict[id]``: list of hypothesis strings in file order.
          There may be several per source. Non-empty hypotheses keep their
          trailing newline; an empty hypothesis is ``""``.
        * ``score_dict[id]``: list of float scores, parallel to
          ``hypothesis_dict[id]``.
        * ``target_dict``: always empty; ``T`` lines are skipped.
        * ``pos_score_dict[id]``: list of per-token float score lists, one per
          ``P`` line.

    Raises:
        AssertionError: if an ``H`` line has no parsable score after its id.
        ValueError: if a ``P`` line contains a token that is not a float.
    """
    # Anchored (used with .match): an unanchored search would also fire on
    # text such as "S-300" or "P-1" appearing inside an unrelated line.
    prefix_re = re.compile(r"([STHP])-(\d+)\s*")
    # Hypothesis score: "-inf"/"inf", or a decimal with optional fraction and
    # exponent ("-3", "-0.25", "-1e-05", "1.5E-3"). Surrounding whitespace is
    # consumed so the remainder of the line is exactly the hypothesis text.
    score_re = re.compile(r"\s*(-?inf|-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)\s*")

    source_dict = {}
    hypothesis_dict = {}
    score_dict = {}
    target_dict = {}
    pos_score_dict = {}

    with open(fle, "r") as f:
        for line in f:
            # Every line ends with exactly one "\n" so stored strings can be
            # written back out verbatim (the last line may lack one).
            if not line.endswith("\n"):
                line += "\n"
            prefix = prefix_re.match(line)
            if prefix is None:
                continue
            line_type = prefix.group(1)
            id_num = int(prefix.group(2))
            rest = line[prefix.end():]

            if line_type == "H":
                score = score_re.match(rest)
                assert score is not None, (
                    "regular expression failed to find the hypothesis scoring"
                )
                # TODO: right-to-left models would need the hypothesis (and
                # its score) reversed here.
                hypothesis_dict.setdefault(id_num, []).append(rest[score.end():])
                score_dict.setdefault(id_num, []).append(float(score.group(1)))
            elif line_type == "S":
                source_dict[id_num] = rest
            elif line_type == "P":
                pos_score_dict.setdefault(id_num, []).append(
                    [float(x) for x in rest.split()]
                )
            # "T" (target/reference) lines are intentionally not stored.

    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
def get_hypo_and_ref(fle, hyp_file, ref_input, ref_file, rank=0):
    """Write line-aligned hypothesis and reference files.

    ``fle`` is parsed with ``reprocess``. Each sentence id that has
    hypotheses is visited in ascending order. For each one, its ``rank``-th
    hypothesis is written to ``hyp_file`` and line ``id`` (0-based) of
    ``ref_input`` is written to ``ref_file``. If ``fle`` contains no
    hypotheses, both output files are created empty.

    Args:
        fle: Path to the generate.py output to parse.
        hyp_file: Output path for the selected hypotheses.
        ref_input: Path of the reference file, one reference per sentence id.
        ref_file: Output path for the matching references.
        rank: Which hypothesis (0 = first listed) to take for each sentence.

    Raises:
        ValueError: if some sentence has no ``rank``-th hypothesis, or
            ``ref_input`` has too few lines for the largest sentence id.
    """
    with open(ref_input, "r") as f:
        refs = f.readlines()
    _, hypo_dict, _, _, _ = reprocess(fle)

    ids = sorted(hypo_dict)
    # Validate everything before opening the outputs so a bad rank or a
    # short reference file cannot leave half-written files behind.
    for idx in ids:
        if rank >= len(hypo_dict[idx]):
            raise ValueError(
                "sentence {} has only {} hypotheses; rank {} requested".format(
                    idx, len(hypo_dict[idx]), rank
                )
            )
    if ids and ids[-1] >= len(refs):
        raise ValueError(
            "reference file {} has {} lines but sentence id {} was found".format(
                ref_input, len(refs), ids[-1]
            )
        )

    with open(hyp_file, "w") as f_hyp, open(ref_file, "w") as f_ref:
        for idx in ids:
            f_hyp.write(hypo_dict[idx][rank])
            f_ref.write(refs[idx])
def recover_bpe(hyp_file):
    """Write a BPE-free copy of ``hyp_file`` to ``hyp_file + ".nobpe"``.

    Every line of ``hyp_file`` is passed through ``remove_bpe_fn`` with its
    default BPE marker.

    Args:
        hyp_file: Path of the hypothesis file to clean.
    """
    # The context managers close both handles even if remove_bpe_fn raises.
    with open(hyp_file, "r") as f_hyp, open(hyp_file + ".nobpe", "w") as f_hyp_out:
        remove_bpe_fn(i=f_hyp, o=f_hyp_out)
if __name__ == "__main__":
    # Usage: <script> GENERATE_OUTPUT REFERENCE_FILE
    # Writes hypo.out, ref.out and hypo.out.nobpe into the directory that
    # contains GENERATE_OUTPUT (the current directory if it has none).
    if len(sys.argv) < 3:
        sys.exit("usage: {} GENERATE_OUTPUT REFERENCE_FILE".format(sys.argv[0]))
    filename = sys.argv[1]
    ref_in = sys.argv[2]
    out_dir = os.path.dirname(filename)
    hypo_file = os.path.join(out_dir, "hypo.out")
    ref_out = os.path.join(out_dir, "ref.out")
    get_hypo_and_ref(filename, hypo_file, ref_in, ref_out)
    recover_bpe(hypo_file)