# transformers/examples/research_projects/codeparrot/scripts/preprocessing.py
import gzip
import hashlib
import json
import multiprocessing
import os
import re
import shutil
import time
from pathlib import Path

import numpy as np
from arguments import PreprocessingArguments
from datasets import load_dataset
from minhash_deduplication import deduplicate_dataset

from transformers import AutoTokenizer, HfArgumentParser

PATTERN = re.compile(r"\s+")
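
# `PATTERN` matches runs of whitespace; `get_hash` below strips them before hashing, so two
# files that differ only in whitespace or indentation produce the same hash and are treated
# as exact duplicates.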

def get_hash(example):
    """Get hash of content field."""
    return {"hash": hashlib.md5(re.sub(PATTERN, "", example["content"]).encode("utf-8")).hexdigest()}


def line_stats(example):
    """Calculate mean and max line length of file."""
    line_lengths = [len(line) for line in example["content"].splitlines()]
    return {"line_mean": np.mean(line_lengths), "line_max": max(line_lengths)}


def alpha_stats(example):
    """Calculate the fraction of alphanumeric characters in the file."""
    alpha_frac = np.mean([c.isalnum() for c in example["content"]])
    return {"alpha_frac": alpha_frac}


def check_uniques(example, uniques):
    """Check if current hash is still in the set of unique hashes and remove it if true."""
    if example["hash"] in uniques:
        uniques.remove(example["hash"])
        return True
    else:
        return False
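
# Exact deduplication works in two passes: `preprocess` stores an MD5 hash per file, the set of
# unique hashes is collected once, and `check_uniques` then removes each hash from the set the
# first time it is seen, so only the first occurrence of a duplicated file survives `filter`.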

def is_autogenerated(example, scan_width=5):
    """Check if file is autogenerated by looking for keywords in the first few lines of the file."""
    keywords = ["auto-generated", "autogenerated", "automatically generated"]
    lines = example["content"].splitlines()
    for _, line in zip(range(scan_width), lines):
        for keyword in keywords:
            if keyword in line.lower():
                return {"autogenerated": True}
    return {"autogenerated": False}

def is_config_or_test(example, scan_width=5, coeff=0.05):
    """Check if file is a configuration file or a unit test by:
    1- looking for keywords in the first few lines of the file.
    2- counting the number of occurrences of the words 'config' and 'test' with respect to the number of lines.
    """
    keywords = ["unit tests", "test file", "configuration file"]
    lines = example["content"].splitlines()
    count_config = 0
    count_test = 0
    # first test: keyword match in the first `scan_width` lines
    for _, line in zip(range(scan_width), lines):
        for keyword in keywords:
            if keyword in line.lower():
                return {"config_or_test": True}
    # second test: frequency of 'config'/'test' relative to file length
    nlines = example["content"].count("\n")
    threshold = int(coeff * nlines)
    for line in lines:
        count_config += line.lower().count("config")
        count_test += line.lower().count("test")
        if count_config > threshold or count_test > threshold:
            return {"config_or_test": True}
    return {"config_or_test": False}

def has_no_keywords(example):
    """Check if a python file has none of the keywords for: function, class, for loop, while loop."""
    keywords = ["def ", "class ", "for ", "while "]
    lines = example["content"].splitlines()
    for line in lines:
        for keyword in keywords:
            if keyword in line.lower():
                return {"has_no_keywords": False}
    return {"has_no_keywords": True}


def has_few_assignments(example, minimum=4):
    """Check if file uses the symbol '=' at most `minimum` times."""
    lines = example["content"].splitlines()
    counter = 0
    for line in lines:
        counter += line.lower().count("=")
        if counter > minimum:
            return {"has_few_assignments": False}
    return {"has_few_assignments": True}

def char_token_ratio(example):
    """Compute character/token ratio of the file with tokenizer."""
    input_ids = tokenizer(example["content"], truncation=False)["input_ids"]
    ratio = len(example["content"]) / len(input_ids)
    return {"ratio": ratio}

def preprocess(example):
    """Chain all preprocessing steps into one function so a single `map` call is cached instead of one per step."""
    results = {}
    results.update(get_hash(example))
    results.update(line_stats(example))
    results.update(alpha_stats(example))
    results.update(char_token_ratio(example))
    results.update(is_autogenerated(example))
    results.update(is_config_or_test(example))
    results.update(has_no_keywords(example))
    results.update(has_few_assignments(example))
    return results

def filter(example, uniques, args):
    """Filter dataset with heuristics. Config, test and has_no_keywords files are removed with a given probability."""
    if not check_uniques(example, uniques):
        return False
    elif example["autogenerated"]:
        return False
    elif example["line_max"] > args.line_max:
        return False
    elif example["line_mean"] > args.line_mean:
        return False
    elif example["alpha_frac"] < args.alpha_frac:
        return False
    elif example["ratio"] < args.min_token_ratio:
        return False
    elif example["config_or_test"] and np.random.rand() <= args.filter_proba:
        return False
    elif example["has_no_keywords"] and np.random.rand() <= args.filter_proba:
        return False
    elif example["has_few_assignments"]:
        return False
    else:
        return True
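
# This `filter` function (which shadows Python's built-in `filter` inside this module) is passed
# to `Dataset.filter` below with `fn_kwargs`, so `uniques` and `args` are forwarded to every call.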

def compress_file(file_path):
    """Compress a file with gzip and remove the uncompressed original."""
    with open(file_path, "rb") as f_in:
        with gzip.open(str(file_path) + ".gz", "wb", compresslevel=6) as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.unlink(file_path)

# Settings
parser = HfArgumentParser(PreprocessingArguments)
args = parser.parse_args()
if args.num_workers is None:
    args.num_workers = multiprocessing.cpu_count()
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)

# Load dataset
t_start = time.time()
ds = load_dataset(args.dataset_name, split="train")
print(f"Time to load dataset: {time.time()-t_start:.2f}")

# Run preprocessing
t_start = time.time()
ds = ds.map(preprocess, num_proc=args.num_workers)
print(f"Time to preprocess dataset: {time.time()-t_start:.2f}")

# Deduplicate hashes
uniques = set(ds.unique("hash"))
frac = len(uniques) / len(ds)
print(f"Fraction of duplicates: {1-frac:.2%}")

# Deduplicate data and apply heuristics
t_start = time.time()
ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
print(f"Size of filtered dataset: {len(ds_filter)}")

# Near-deduplicate with MinHash and Jaccard similarity
if args.near_deduplication:
    t_start = time.time()
    ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold)
    print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}")
    print(f"Size of deduplicated dataset: {len(ds_filter)}")

# Save data in batches of samples_per_file
output_dir = Path(args.output_dir)
output_dir.mkdir(exist_ok=True)

# Save duplicate_clusters in the output_dir as an artifact
# (not sure this is the right place to save it)
if args.near_deduplication:
    with open(output_dir / "duplicate_clusters.json", "w") as f:
        json.dump(duplicate_clusters, f)

data_dir = output_dir / "data"
data_dir.mkdir(exist_ok=True)
t_start = time.time()
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
    file_path = str(data_dir / f"file-{file_number+1:012}.json")
    end_index = min(len(ds_filter), index + args.samples_per_file)
    ds_filter.select(list(range(index, end_index))).to_json(file_path)
    compress_file(file_path)
print(f"Time to save dataset: {time.time()-t_start:.2f}")