Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # Author: Rico Sennrich | |
| """Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text. | |
| This script learns BPE jointly on a concatenation of a list of texts (typically the source and target side of a parallel corpus), | |
| applies the learned operations to each text, and writes the resulting vocabulary of each text to a file. | |
| The vocabulary can be used in apply_bpe.py to avoid producing symbols that are rare or OOV in a training text. | |
| Reference: | |
| Rico Sennrich, Barry Haddow and Alexandra Birch (2016). Neural Machine Translation of Rare Words with Subword Units. | |
| Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany. | |
| """ | |
| from __future__ import unicode_literals | |
| import sys | |
| import os | |
| import inspect | |
| import codecs | |
| import argparse | |
| import tempfile | |
| import warnings | |
| from collections import Counter | |
| from multiprocessing import cpu_count | |
| #hack to get imports working if running this as a script, or within a package | |
| if __name__ == '__main__': | |
| import learn_bpe | |
| import apply_bpe | |
| else: | |
| from . import learn_bpe | |
| from . import apply_bpe | |
| # hack for python2/3 compatibility | |
| from io import open | |
| argparse.open = open | |
def create_parser(subparsers=None):
    """Build the command-line parser for joint BPE learning plus vocabulary output.

    Args:
        subparsers: optional object returned by ``ArgumentParser.add_subparsers()``.
            When truthy, the parser is registered on it as the
            ``learn-joint-bpe-and-vocab`` sub-command; otherwise a standalone
            ``argparse.ArgumentParser`` is created.

    Returns:
        The configured parser (standalone or sub-command).
    """
    # Both construction paths share the same formatter and description.
    shared = dict(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="learn BPE-based word segmentation")
    if subparsers:
        parser = subparsers.add_parser('learn-joint-bpe-and-vocab', **shared)
    else:
        parser = argparse.ArgumentParser(**shared)

    option = parser.add_argument
    option('--input', '-i', type=argparse.FileType('r'), required=True, nargs='+',
           metavar='PATH',
           help="Input texts (multiple allowed).")
    option('--output', '-o', type=argparse.FileType('w'), required=True,
           metavar='PATH',
           help="Output file for BPE codes.")
    option('--symbols', '-s', type=int, default=10000,
           help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
    option('--separator', type=str, default='@@', metavar='STR',
           help="Separator between non-final subword units (default: '%(default)s')")
    # Stored under ``args.vocab``. The option is required, so argparse never
    # needs a fallback default for it.
    option('--write-vocabulary', type=argparse.FileType('w'), required=True, nargs='+',
           metavar='PATH', dest='vocab',
           help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
    option('--min-frequency', type=int, default=2, metavar='FREQ',
           help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
    option('--total-symbols', '-t', action="store_true",
           help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
    option('--num-workers', type=int, default=1,
           help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
    option('--verbose', '-v', action="store_true",
           help="verbose mode.")
    return parser
def _close_and_get_names(handles):
    """Return the ``.name`` of every file object in *handles*, closing each one.

    ``argparse.FileType`` opens the files while the command line is parsed, but
    this module only needs their paths (they are re-opened as UTF-8), so the
    original handles are released here instead of being leaked. The process's
    standard streams (e.g. if ``-`` was passed on the command line) are left open.
    """
    standard_streams = (sys.stdin, sys.stdout, sys.stderr)
    names = []
    for handle in handles:
        names.append(handle.name)
        if not any(handle is stream for stream in standard_streams):
            handle.close()
    return names


def _segmented_vocabulary(bpe, train_path, num_workers):
    """Segment the text at *train_path* with *bpe* and return its symbol counts.

    The segmented text goes to a temporary file, which is read back with
    ``learn_bpe.get_vocabulary``. The file is always removed, even if
    segmentation or counting raises.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False)
    # Only the unique path is needed; it is re-opened below with UTF-8 codecs.
    tmp.close()
    try:
        with codecs.open(tmp.name, 'w', encoding='UTF-8') as tmpout:
            # process_lines receives the path, not an open handle.
            bpe.process_lines(train_path, tmpout, num_workers=num_workers)
        with codecs.open(tmp.name, encoding='UTF-8') as tmpin:
            return learn_bpe.get_vocabulary(tmpin, num_workers=num_workers)
    finally:
        os.remove(tmp.name)


def learn_joint_bpe_and_vocab(args):
    """Learn BPE codes jointly on all input texts, then write one vocabulary per text.

    Steps:
      1. Sum the ``learn_bpe.get_vocabulary`` counts of every file in ``args.input``.
      2. Pass that combined table to ``learn_bpe.learn_bpe``, writing the codes
         to ``args.output``.
      3. Load the codes into ``apply_bpe.BPE`` and segment each input file.
      4. Write the symbol counts of each segmented file to the matching file in
         ``args.vocab``, one ``"<symbol> <count>"`` line each, most frequent first.

    Args:
        args: namespace as produced by ``create_parser()``. Reads ``input``,
            ``output``, ``vocab``, ``symbols``, ``separator``,
            ``min_frequency``, ``verbose``, ``total_symbols`` and
            ``num_workers``. Only the ``.name`` of the file objects in
            ``input``/``output``/``vocab`` is used. They are re-opened as
            UTF-8, and the original handles are closed. ``args`` itself is not
            modified. If ``args.vocab`` is empty or None, only the codes are
            written.

    Raises:
        SystemExit: if the numbers of input and vocabulary files differ.
    """
    vocab_targets = args.vocab or []
    if vocab_targets and len(args.input) != len(vocab_targets):
        sys.stderr.write('Error: number of input files and vocabulary files must match\n')
        sys.exit(1)

    input_paths = _close_and_get_names(args.input)
    vocab_paths = _close_and_get_names(vocab_targets)
    output_path = _close_and_get_names([args.output])[0]

    # Combined word-frequency table over all input texts (read as UTF-8).
    full_vocab = Counter()
    for path in input_paths:
        with codecs.open(path, encoding='UTF-8') as infile:
            full_vocab += learn_bpe.get_vocabulary(infile, num_workers=args.num_workers)

    # learn_bpe is called with is_dict=True, so the counts are passed as "word freq" lines.
    vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]

    with codecs.open(output_path, 'w', encoding='UTF-8') as output:
        learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency,
                            args.verbose, is_dict=True, total_symbols=args.total_symbols)

    with codecs.open(output_path, encoding='UTF-8') as codes:
        bpe = apply_bpe.BPE(codes, separator=args.separator)

    # Apply BPE to each training corpus and record its post-segmentation vocabulary.
    for train_path, vocab_path in zip(input_paths, vocab_paths):
        vocab = _segmented_vocabulary(bpe, train_path, args.num_workers)
        with codecs.open(vocab_path, 'w', encoding='UTF-8') as vocab_file:
            for key, freq in sorted(vocab.items(), key=lambda x: x[1], reverse=True):
                vocab_file.write("{0} {1}\n".format(key, freq))
if __name__ == '__main__':
    # When run from the legacy top-level location next to the package
    # directory, warn that the script has moved.
    currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    newdir = os.path.join(currentdir, 'subword_nmt')
    if os.path.isdir(newdir):
        warnings.simplefilter('default')
        warnings.warn(
            "this script's location has moved to {0}. This symbolic link will be removed in a future version. Please point to the new location, or install the package and use the command 'subword-nmt'".format(newdir),
            DeprecationWarning
        )

    # python 2/3 compatibility: wrap the standard streams with UTF-8 codecs
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    else:
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)

    parser = create_parser()
    args = parser.parse_args()

    # Any non-positive worker count means "use every available CPU".
    if args.num_workers <= 0:
        args.num_workers = cpu_count()

    if sys.version_info < (3, 0):
        args.separator = args.separator.decode('UTF-8')
        if args.num_workers > 1:
            args.num_workers = 1
            warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")

    # Validate explicitly instead of with `assert`, which is stripped under
    # `python -O` and would otherwise surface as a bare traceback.
    if len(args.input) != len(args.vocab):
        parser.error('number of input files and vocabulary files must match')

    learn_joint_bpe_and_vocab(args)