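# Rank knowledge-store sentences for each claim with BM25, in parallel.
#
# For every claim index in [start, end), the script loads that claim's knowledge-store
# file ({index}.jsonl), deduplicates its sentences, scores them with BM25 against the
# claim text concatenated with its hypothetical fact-checking documents ("hypo_fc_docs"),
# and writes the top_k sentences (with their source URLs) as one JSON line per claim.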
import argparse
import json
import os
import time
import numpy as np
import pandas as pd
import nltk
from rank_bm25 import BM25Okapi
from multiprocessing import Pool, cpu_count, Manager, Lock
from functools import partial
import heapq
from threading import Thread, Event
import queue
from datetime import datetime, timedelta


def download_nltk_data(package_name, download_dir='nltk_data'):
    # Ensure the download directory exists
    os.makedirs(download_dir, exist_ok=True)
    # Set NLTK data path
    nltk.data.path.append(download_dir)
    try:
        # Try to find the resource
        nltk.data.find(f'tokenizers/{package_name}')
        print(f"Package '{package_name}' is already downloaded")
    except LookupError:
        # If the resource isn't found, download it
        print(f"Downloading {package_name}...")
        nltk.download(package_name, download_dir=download_dir)
        print(f"Successfully downloaded {package_name}")


def combine_all_sentences(knowledge_file):
    # Gather every sentence in the claim's knowledge-store file along with its source URL
    sentences, urls = [], []
    with open(knowledge_file, "r", encoding="utf-8") as json_file:
        for i, line in enumerate(json_file):
            data = json.loads(line)
            sentences.extend(data["url2text"])
            urls.extend([data["url"] for _ in range(len(data["url2text"]))])
    return sentences, urls, i + 1


def remove_duplicates(sentences, urls):
    # Drop sentences that are identical after whitespace stripping and lowercasing
    df = pd.DataFrame({"document_in_sentences": sentences, "sentence_urls": urls})
    df['sentences'] = df['document_in_sentences'].str.strip().str.lower()
    df = df.drop_duplicates(subset="sentences").reset_index()
    return df['document_in_sentences'].tolist(), df['sentence_urls'].tolist()


def retrieve_top_k_sentences(query, document, urls, top_k):
    # Score every sentence against the query with BM25 and keep the top_k highest-scoring ones
    tokenized_docs = [nltk.word_tokenize(doc) for doc in document]
    bm25 = BM25Okapi(tokenized_docs)
    scores = bm25.get_scores(nltk.word_tokenize(query))
    top_k_idx = np.argsort(scores)[::-1][:top_k]
    return [document[i] for i in top_k_idx], [urls[i] for i in top_k_idx]
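
# A minimal sketch of how retrieve_top_k_sentences behaves (the toy corpus and URLs
# below are illustrative assumptions, not data from this project):
#
#   docs = ["The cat sat on the mat.", "Dogs bark loudly.", "Cats nap in the sun."]
#   urls = ["http://example.com/a", "http://example.com/b", "http://example.com/c"]
#   sents, srcs = retrieve_top_k_sentences("cat on the mat", docs, urls, top_k=2)
#   # sents holds the two sentences with the highest BM25 scores and srcs their URLs,
#   # both ordered from most to least relevant.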


def process_single_example(idx, example, args, result_queue, counter, lock):
    try:
        # Bump the shared progress counter under the manager lock
        with lock:
            current_count = counter.value + 1
            counter.value = current_count
        print(f"\nProcessing claim {idx}... Progress: {current_count} / {args.total_examples}")

        # start_time = time.time()

        document_in_sentences, sentence_urls, num_urls_this_claim = combine_all_sentences(
            os.path.join(args.knowledge_store_dir, f"{idx}.jsonl")
        )
        print(f"Obtained {len(document_in_sentences)} sentences from {num_urls_this_claim} urls.")

        document_in_sentences, sentence_urls = remove_duplicates(document_in_sentences, sentence_urls)

        # The BM25 query is the claim text concatenated with its hypothetical fact-checking documents
        query = example["claim"] + " " + " ".join(example['hypo_fc_docs'])
        top_k_sentences, top_k_urls = retrieve_top_k_sentences(
            query, document_in_sentences, sentence_urls, args.top_k
        )

        result = {
            "claim_id": idx,
            "claim": example["claim"],
            f"top_{args.top_k}": [
                {"sentence": sent, "url": url}
                for sent, url in zip(top_k_sentences, top_k_urls)
            ],
            "hypo_fc_docs": example['hypo_fc_docs'],
        }
        result_queue.put((idx, result))
        return True
    except Exception as e:
        print(f"Error processing example {idx}: {str(e)}")
        result_queue.put((idx, None))
        return False


def writer_thread(output_file, result_queue, total_examples, stop_event, start_index=0):
    # Consume results from the queue and write them to disk in claim-index order,
    # buffering out-of-order results in a min-heap keyed by claim index.
    next_index = start_index
    pending_results = []
    with open(output_file, "w", encoding="utf-8") as f:
        while not (stop_event.is_set() and result_queue.empty()):
            try:
                idx, result = result_queue.get(timeout=1)
                heapq.heappush(pending_results, (idx, result))
                while pending_results and pending_results[0][0] == next_index:
                    _, result_to_write = heapq.heappop(pending_results)
                    if result_to_write is not None:  # failed examples are skipped, not written
                        f.write(json.dumps(result_to_write, ensure_ascii=False) + "\n")
                        f.flush()
                    next_index += 1
            except queue.Empty:
                continue


# def format_time(seconds):
#     """Format time duration nicely."""
#     return str(timedelta(seconds=round(seconds)))


def main(args):
    # Make sure the NLTK tokenizer models are available locally
    download_nltk_data('punkt')
    download_nltk_data('punkt_tab')

    with open(args.target_data, "r", encoding="utf-8") as json_file:
        target_examples = json.load(json_file)

    if args.end == -1:
        args.end = len(target_examples)
    print(f"Total examples to process: {args.end - args.start}")

    files_to_process = list(range(args.start, args.end))
    examples_to_process = [(idx, target_examples[idx]) for idx in files_to_process]

    num_workers = min(args.workers if args.workers > 0 else cpu_count(), len(files_to_process))
    print(f"Using {num_workers} workers to process {len(files_to_process)} examples")

    with Manager() as manager:
        counter = manager.Value('i', 0)
        lock = manager.Lock()
        args.total_examples = len(files_to_process)
        result_queue = manager.Queue()
        stop_event = Event()

        # The writer runs in a background thread of the main process so results
        # can be flushed to disk as soon as they arrive from the workers.
        writer = Thread(
            target=writer_thread,
            args=(args.json_output, result_queue, len(files_to_process), stop_event, args.start)
        )
        writer.start()

        process_func = partial(
            process_single_example,
            args=args,
            result_queue=result_queue,
            counter=counter,
            lock=lock
        )

        with Pool(num_workers) as pool:
            results = pool.starmap(process_func, examples_to_process)

        stop_event.set()
        writer.join()

        # successful = sum(1 for r in results if r)
        # print(f"\nSuccessfully processed {successful} out of {len(files_to_process)} examples")
        # print(f"Results written to {args.json_output}")

        # # Calculate and display timing information
        # total_time = time.time() - script_start
        # avg_time = total_time / len(files_to_process)
        # end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # print("\nTiming Summary:")
        # print(f"Start time: {start_time}")
        # print(f"End time: {end_time}")
        # print(f"Total runtime: {format_time(total_time)} (HH:MM:SS)")
        # print(f"Average time per example: {avg_time:.2f} seconds")
        # if successful > 0:
        #     print(f"Processing speed: {successful / total_time:.2f} examples per second")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Get the top-k sentences with BM25 from the knowledge store using parallel processing."
    )
    parser.add_argument(
        "-k",
        "--knowledge_store_dir",
        type=str,
        default="data_store/knowledge_store",
        help="Path of the knowledge store directory containing one JSONL file of retrieved sentences per claim.",
    )
    parser.add_argument(
        "--target_data",
        type=str,
        default="data_store/hyde_fc.json",
        help="Path of the JSON file that stores the claims and their hypothetical fact-checking documents.",
    )
    parser.add_argument(
        "-o",
        "--json_output",
        type=str,
        default="data_store/dev_retrieval_top_k.json",
        help="Output file in which to save the top-k sentences for each claim, one JSON line per claim.",
    )
    parser.add_argument(
        "--top_k",
        default=5000,
        type=int,
        help="How many sentences to pick out with BM25 for each claim.",
    )
    parser.add_argument(
        "-s",
        "--start",
        type=int,
        default=0,
        help="Starting index of the claims to process.",
    )
    parser.add_argument(
        "-e",
        "--end",
        type=int,
        default=-1,
        help="End index (exclusive) of the claims to process; -1 means process to the end.",
    )
    parser.add_argument(
        "-w",
        "--workers",
        type=int,
        default=0,
        help="Number of worker processes (default: number of CPU cores).",
    )

    args = parser.parse_args()
    main(args)
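
# Example invocation (a sketch; the script filename and worker count are assumptions,
# and the paths simply restate the defaults above):
#
#   python bm25_retrieval.py \
#       --knowledge_store_dir data_store/knowledge_store \
#       --target_data data_store/hyde_fc.json \
#       --json_output data_store/dev_retrieval_top_k.json \
#       --top_k 5000 \
#       --start 0 --end -1 \
#       --workers 8
#
# Each claim i is expected to have a knowledge-store file at
# data_store/knowledge_store/i.jsonl; the output contains one JSON line per claim
# with its top_k BM25-ranked sentences and their source URLs.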