#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich

"""Use operations learned with learn_bpe.py to encode a new text.
The text will not be smaller, but use only a fixed vocabulary, with rare words
encoded as variable-length sequences of subword units.

Reference:
Rico Sennrich, Barry Haddow and Alexandra Birch (2016). Neural Machine Translation of Rare Words with Subword Units.
Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany.
"""

from __future__ import unicode_literals, division

import sys
import os
import inspect
import codecs
import io
import argparse
import re
import warnings
import random
import tempfile
from multiprocessing import Pool, cpu_count

# hack for python2/3 compatibility
from io import open
argparse.open = open

class BPE(object):

    def __init__(self, codes, merges=-1, separator='@@', vocab=None, glossaries=None):

        codes.seek(0)
        offset = 1

        # check version information
        firstline = codes.readline()
        if firstline.startswith('#version:'):
            self.version = tuple([int(x) for x in re.sub(r'(\.0+)*$', '', firstline.split()[-1]).split(".")])
            offset += 1
        else:
            self.version = (0, 1)
            codes.seek(0)

        self.bpe_codes = [tuple(item.strip('\r\n ').split(' ')) for (n, item) in enumerate(codes.read().rstrip('\n').split('\n')) if (n < merges or merges == -1)]

        for i, item in enumerate(self.bpe_codes):
            if len(item) != 2:
                sys.stderr.write('Error: invalid line {0} in BPE codes file: {1}\n'.format(i + offset, ' '.join(item)))
                sys.stderr.write('The line should consist of exactly two subword units, separated by whitespace\n')
                sys.exit(1)

        # some hacking to deal with duplicates (only consider first instance)
        self.bpe_codes = dict([(code, i) for (i, code) in reversed(list(enumerate(self.bpe_codes)))])

        self.bpe_codes_reverse = dict([(pair[0] + pair[1], pair) for pair, i in self.bpe_codes.items()])

        self.separator = separator

        self.vocab = vocab

        self.glossaries = glossaries if glossaries else []

        self.glossaries_regex = re.compile('^({})$'.format('|'.join(glossaries))) if glossaries else None

        self.cache = {}
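
    # Illustrative layout of the codes file parsed above (a made-up example,
    # not from a real model): an optional '#version: 0.2' header written by
    # learn_bpe.py, then one merge operation per line, highest priority first:
    #
    #   #version: 0.2
    #   t h
    #   th e</w>
    #   l o
    #   lo w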

    def process_lines(self, filename, outfile, dropout=0, num_workers=1):

        if sys.version_info < (3, 0):
            print("Parallel mode is only supported in Python3.")
            sys.exit(1)

        if num_workers == 1:
            _process_lines(self, filename, outfile, dropout, 0, 0)
        elif num_workers > 1:
            with open(filename, encoding="utf-8") as f:
                size = os.fstat(f.fileno()).st_size
                chunk_size = int(size / num_workers)
                offsets = [0 for _ in range(num_workers + 1)]
                for i in range(1, num_workers):
                    f.seek(chunk_size * i)
                    pos = f.tell()
                    while True:
                        try:
                            line = f.readline()
                            break
                        except UnicodeDecodeError:
                            pos -= 1
                            f.seek(pos)
                    offsets[i] = f.tell()
                    assert 0 <= offsets[i] < 1e20, "Bad new line separator, e.g. '\\r'"
            res_files = []
            pool = Pool(processes=num_workers)
            for i in range(num_workers):
                tmp = tempfile.NamedTemporaryFile(delete=False)
                tmp.close()
                res_files.append(tmp)
                pool.apply_async(_process_lines, (self, filename, tmp.name, dropout, offsets[i], offsets[i + 1]))
            pool.close()
            pool.join()
            for i in range(num_workers):
                with open(res_files[i].name, encoding="utf-8") as fi:
                    for line in fi:
                        outfile.write(line)
                os.remove(res_files[i].name)
        else:
            raise ValueError('`num_workers` is expected to be a positive number, but got {}.'.format(num_workers))

    def process_line(self, line, dropout=0):
        """segment line, dealing with leading and trailing whitespace"""

        out = ""

        leading_whitespace = len(line) - len(line.lstrip('\r\n '))
        if leading_whitespace:
            out += line[:leading_whitespace]

        out += self.segment(line, dropout)

        trailing_whitespace = len(line) - len(line.rstrip('\r\n '))
        if trailing_whitespace and trailing_whitespace != len(line):
            out += line[-trailing_whitespace:]

        return out

    def segment(self, sentence, dropout=0):
        """segment single sentence (whitespace-tokenized string) with BPE encoding"""
        segments = self.segment_tokens(sentence.strip('\r\n ').split(' '), dropout)
        return ' '.join(segments)

    def segment_tokens(self, tokens, dropout=0):
        """segment a sequence of tokens with BPE encoding"""
        output = []
        for word in tokens:
            # eliminate double spaces
            if not word:
                continue
            new_word = [out for segment in self._isolate_glossaries(word)
                        for out in encode(segment,
                                          self.bpe_codes,
                                          self.bpe_codes_reverse,
                                          self.vocab,
                                          self.separator,
                                          self.version,
                                          self.cache,
                                          self.glossaries_regex,
                                          dropout)]

            for item in new_word[:-1]:
                output.append(item + self.separator)
            output.append(new_word[-1])

        return output

    def _isolate_glossaries(self, word):
        word_segments = [word]
        for gloss in self.glossaries:
            word_segments = [out_segments for segment in word_segments
                             for out_segments in isolate_glossary(segment, gloss)]
        return word_segments
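
# A minimal programmatic-use sketch (hypothetical file name 'codes.bpe'; the
# constructor and method signatures are the ones defined above):
#
#   import codecs
#   bpe = BPE(codecs.open('codes.bpe', encoding='utf-8'))
#   bpe.process_line('irresistible subwords\n')   # e.g. 'irresist@@ ible sub@@ words\n'
#   bpe.segment_tokens(['irresistible'])          # e.g. ['irresist@@', 'ible']
#
# The exact segmentation depends on the learned merges, so the outputs shown
# are illustrative only.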

def _process_lines(bpe, filename, outfile, dropout, begin, end):
    if isinstance(outfile, str):
        fo = open(outfile, "w", encoding="utf-8")
    else:
        fo = outfile
    with open(filename, encoding="utf-8") as f:
        f.seek(begin)
        line = f.readline()
        while line:
            pos = f.tell()
            assert 0 <= pos < 1e20, "Bad new line separator, e.g. '\\r'"
            if end > 0 and pos > end:
                break
            fo.write(bpe.process_line(line, dropout))
            line = f.readline()
    if isinstance(outfile, str):
        fo.close()
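
# A worked example of the chunking scheme in process_lines (assumed numbers):
# for a 1000-byte file and num_workers=4, chunk_size is 250, so the parent
# seeks to bytes 250, 500 and 750 and advances each position to the next line
# boundary with readline(). offsets might then be [0, 260, 511, 753, 0]; the
# final 0 is deliberate, since _process_lines treats end=0 as "read to EOF".
# Worker i encodes the byte range [offsets[i], offsets[i+1]) into its own
# temporary file, and the parent concatenates the files in order.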

def create_parser(subparsers=None):

    if subparsers:
        parser = subparsers.add_parser('apply-bpe',
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="apply BPE-based word segmentation")
    else:
        parser = argparse.ArgumentParser(
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="apply BPE-based word segmentation")

    parser.add_argument(
        '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
        metavar='PATH',
        help="Input file (default: standard input).")
    parser.add_argument(
        '--codes', '-c', type=argparse.FileType('r'), metavar='PATH',
        required=True,
        help="File with BPE codes (created by learn_bpe.py).")
    parser.add_argument(
        '--merges', '-m', type=int, default=-1,
        metavar='INT',
        help="Use this many BPE operations (<= number of learned symbols) " +
             "(default: apply all the learned merge operations)")
    parser.add_argument(
        '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
        metavar='PATH',
        help="Output file (default: standard output)")
    parser.add_argument(
        '--separator', '-s', type=str, default='@@', metavar='STR',
        help="Separator between non-final subword units (default: '%(default)s')")
    parser.add_argument(
        '--vocabulary', type=argparse.FileType('r'), default=None,
        metavar="PATH",
        help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.")
    parser.add_argument(
        '--vocabulary-threshold', type=int, default=None,
        metavar="INT",
        help="Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV")
    parser.add_argument(
        '--dropout', type=float, default=0,
        metavar="P",
        help="Dropout BPE merge operations with probability P (Provilkov et al., 2019). Use this on training data only.")
    parser.add_argument(
        '--glossaries', type=str, nargs='+', default=None,
        metavar="STR",
        help="Glossaries. Words matching any of the words/regex provided in glossaries will not be affected " +
             "by the BPE (i.e. they will neither be broken into subwords, nor concatenated with other subwords). " +
             "Can be provided as a list of words/regex after the --glossaries argument. Enclose each regex in quotes.")
    parser.add_argument(
        '--seed', type=int, default=None,
        metavar="S",
        help="Random seed for the random number generators (e.g. for BPE dropout with --dropout).")
    parser.add_argument(
        '--num-workers', type=int, default=1,
        help="Number of processors used to process texts; only supported in Python3. If -1, use `multiprocessing.cpu_count()`. (default: %(default)s)")

    return parser

def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries_regex=None, dropout=0):
    """Encode word based on list of BPE merge operations, which are applied consecutively"""

    if not dropout and orig in cache:
        return cache[orig]

    if glossaries_regex and glossaries_regex.match(orig):
        cache[orig] = (orig,)
        return (orig,)

    if len(orig) == 1:
        return orig

    if version == (0, 1):
        word = list(orig) + ['</w>']
    elif version == (0, 2):  # more consistent handling of word-final segments
        word = list(orig[:-1]) + [orig[-1] + '</w>']
    else:
        raise NotImplementedError

    while len(word) > 1:

        # get list of symbol pairs; optionally apply dropout
        pairs = [(bpe_codes[pair], i, pair) for (i, pair) in enumerate(zip(word, word[1:])) if (not dropout or random.random() > dropout) and pair in bpe_codes]

        if not pairs:
            break

        # get first merge operation in list of BPE codes
        bigram = min(pairs)[2]

        # find start position of all pairs that we want to merge
        positions = [i for (rank, i, pair) in pairs if pair == bigram]

        i = 0
        new_word = []
        bigram = ''.join(bigram)
        for j in positions:
            # merges are invalid if they start before current position. This can happen if there are overlapping pairs: (x x x -> xx x)
            if j < i:
                continue
            new_word.extend(word[i:j])  # all symbols before merged pair
            new_word.append(bigram)  # merged pair
            i = j + 2  # continue after merged pair
        new_word.extend(word[i:])  # add all symbols until end of word
        word = new_word

    # don't print end-of-word symbols
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word[-1] = word[-1][:-4]

    word = tuple(word)
    if vocab:
        word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator)

    cache[orig] = word
    return word
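
# A step-by-step trace of the merge loop above under assumed codes (a made-up
# ranking; lower rank merges first). With version (0, 2) and
# bpe_codes = {('l','o'): 0, ('lo','w'): 1, ('low','er</w>'): 2, ('e','r</w>'): 3},
# encode('lower', ...) proceeds as:
#
#   ['l', 'o', 'w', 'e', 'r</w>']  ->  ['lo', 'w', 'e', 'r</w>']   (rank 0)
#   ['lo', 'w', 'e', 'r</w>']      ->  ['low', 'e', 'r</w>']       (rank 1)
#   ['low', 'e', 'r</w>']          ->  ['low', 'er</w>']           (rank 3)
#   ['low', 'er</w>']              ->  ['lower</w>']               (rank 2)
#
# After stripping '</w>', the cached result is ('lower',).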

def recursive_split(segment, bpe_codes, vocab, separator, final=False):
    """Recursively split segment into smaller units (by reversing BPE merges)
    until all units are either in-vocabulary, or cannot be split further."""
    try:
        if final:
            left, right = bpe_codes[segment + '</w>']
            right = right[:-4]
        else:
            left, right = bpe_codes[segment]
    except KeyError:
        # sys.stderr.write('cannot split {0} further.\n'.format(segment))
        yield segment
        return

    if left + separator in vocab:
        yield left
    else:
        for item in recursive_split(left, bpe_codes, vocab, separator, False):
            yield item

    if (final and right in vocab) or (not final and right + separator in vocab):
        yield right
    else:
        for item in recursive_split(right, bpe_codes, vocab, separator, final):
            yield item

def check_vocab_and_split(orig, bpe_codes, vocab, separator):
    """Check for each segment in word if it is in-vocabulary,
    and segment OOV segments into smaller units by reversing the BPE merge operations"""

    out = []

    for segment in orig[:-1]:
        if segment + separator in vocab:
            out.append(segment)
        else:
            # sys.stderr.write('OOV: {0}\n'.format(segment))
            for item in recursive_split(segment, bpe_codes, vocab, separator, False):
                out.append(item)

    segment = orig[-1]
    if segment in vocab:
        out.append(segment)
    else:
        # sys.stderr.write('OOV: {0}\n'.format(segment))
        for item in recursive_split(segment, bpe_codes, vocab, separator, True):
            out.append(item)

    return out
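
# Sketch of the vocabulary filter reverting a merge (assumed data): suppose
# encode() produced ('lower',) but only 'low@@' and 'er' are in the
# vocabulary. bpe_codes_reverse maps concatenations back to pairs, e.g.
# {'lower</w>': ('low', 'er</w>')}, so recursive_split undoes the final merge
# and check_vocab_and_split returns ['low', 'er'], which segment_tokens then
# writes out as 'low@@ er'.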

def read_vocabulary(vocab_file, threshold):
    """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold."""

    vocabulary = set()

    for line in vocab_file:
        word, freq = line.strip('\r\n ').split(' ')
        freq = int(freq)
        if threshold is None or freq >= threshold:
            vocabulary.add(word)

    return vocabulary
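
# Expected vocabulary-file layout (an invented example): one 'token count'
# pair per line, as produced by get_vocab.py on BPE-segmented text. With
# threshold=50, only 'low@@' and 'er' would be kept here:
#
#   low@@ 120
#   er 85
#   est@@ 12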

def isolate_glossary(word, glossary):
    """
    Isolate a glossary present inside a word.

    Returns a list of subwords in which all occurrences of the glossary are isolated.
    For example, if 'USA' is the glossary and '1934USABUSA' the word, the return value is:
        ['1934', 'USA', 'B', 'USA']
    """
    # regex equivalent of (if word == glossary or glossary not in word)
    if re.match('^' + glossary + '$', word) or not re.search(glossary, word):
        return [word]
    else:
        segments = re.split(r'({})'.format(glossary), word)
        segments, ending = segments[:-1], segments[-1]
        segments = list(filter(None, segments))  # Remove empty strings in regex group.
        return segments + [ending.strip('\r\n ')] if ending != '' else segments
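
# Doctest-style illustration of isolate_glossary, reusing the example from its
# docstring:
#
#   >>> isolate_glossary('1934USABUSA', 'USA')
#   ['1934', 'USA', 'B', 'USA']
#   >>> isolate_glossary('USA', 'USA')   # the whole word matches: unchanged
#   ['USA']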

if __name__ == '__main__':

    currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    newdir = os.path.join(currentdir, 'subword_nmt')
    if os.path.isdir(newdir):
        warnings.simplefilter('default')
        warnings.warn(
            "this script's location has moved to {0}. This symbolic link will be removed in a future version. Please point to the new location, or install the package and use the command 'subword-nmt'".format(newdir),
            DeprecationWarning
        )

    # python 2/3 compatibility
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    else:
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True)

    parser = create_parser()
    args = parser.parse_args()

    if args.num_workers <= 0:
        args.num_workers = cpu_count()

    # read/write files as UTF-8
    args.codes = codecs.open(args.codes.name, encoding='utf-8')
    if args.input.name != '<stdin>':
        args.input = codecs.open(args.input.name, encoding='utf-8')
    if args.output.name != '<stdout>':
        args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
    if args.vocabulary:
        args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')

    if args.vocabulary:
        vocabulary = read_vocabulary(args.vocabulary, args.vocabulary_threshold)
    else:
        vocabulary = None

    if sys.version_info < (3, 0):
        args.separator = args.separator.decode('UTF-8')
        if args.glossaries:
            args.glossaries = [g.decode('UTF-8') for g in args.glossaries]
        if args.num_workers > 1:
            args.num_workers = 1
            warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")

    if args.seed is not None:
        random.seed(args.seed)

    bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)

    if args.input.name == '<stdin>' or args.num_workers == 1:
        if args.num_workers > 1:
            warnings.warn("In parallel mode, the input cannot be STDIN. Using 1 processor instead.")
        for line in args.input:
            args.output.write(bpe.process_line(line, args.dropout))
    else:
        bpe.process_lines(args.input.name, args.output, args.dropout, args.num_workers)