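"""Rerank retrieved evidence sentences for each claim with a HyDE-style query.

For every claim, the claim plus its hypothetical fact-checking documents
('hypo_fc_docs') are embedded with Salesforce/SFR-Embedding-2_R and averaged
into a single query vector; the retrieved candidate sentences ('top_5000')
are re-scored by cosine similarity against that vector, deduplicated, and the
top-k results are written out as JSON Lines.
"""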
import torch
import gc
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
import numpy as np
import json
import argparse
import time
from datetime import datetime, timedelta
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def encode_text(model, tokenizer, texts, batch_size=8, max_length=512):
    """Encode texts into mean-pooled sentence embeddings using an AutoModel."""
    all_embeddings = []
    device_type = model.device.type
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        # Tokenize the batch and move it to the model's device.
        encoded_input = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(model.device)
        # Compute token embeddings in bfloat16 to save memory.
        with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            model_output = model(**encoded_input)
        # Mean pooling: average token embeddings, ignoring padding positions.
        attention_mask = encoded_input['attention_mask']
        token_embeddings = model_output[0]  # First element contains token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        all_embeddings.append(embeddings.cpu().numpy())
        # Periodically release cached GPU memory.
        if i % (batch_size * 4) == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
    return np.vstack(all_embeddings)
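
# A minimal usage sketch for encode_text (model and tokenizer names match the
# ones loaded in main(); the input strings are purely illustrative):
#
#   tokenizer = AutoTokenizer.from_pretrained("Salesforce/SFR-Embedding-2_R")
#   model = AutoModel.from_pretrained("Salesforce/SFR-Embedding-2_R",
#                                     torch_dtype=torch.bfloat16, device_map="auto")
#   embs = encode_text(model, tokenizer, ["first text", "second text"])
#   # embs is a (2, hidden_size) numpy array, one mean-pooled row per text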


def compute_similarity(emb1, emb2):
    """Compute pairwise cosine similarity between two embedding matrices."""
    return np.dot(emb1, emb2.T) / (
        np.linalg.norm(emb1, axis=1).reshape(-1, 1) *
        np.linalg.norm(emb2, axis=1).reshape(1, -1)
    )
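
# Example: compute_similarity returns a (len(emb1), len(emb2)) matrix of
# cosine similarities, so identical rows score 1.0:
#
#   a = np.array([[1.0, 0.0], [0.0, 1.0]])
#   compute_similarity(a, a)  # -> [[1., 0.], [0., 1.]]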


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery: {query}'
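
# Example: get_detailed_instruct(task, "who wrote Hamlet") returns
#   "Instruct: Given a web search query, retrieve relevant passages that
#    answer the query\nQuery: who wrote Hamlet"
# which is the instruction format this script uses for all query texts.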


def preprocess_sentences(sentence1, sentence2):
    """Return the TF-IDF cosine similarity between two sentences."""
    tfidf_matrix = TfidfVectorizer().fit_transform([sentence1, sentence2])
    cosine_sim = cosine_similarity(tfidf_matrix.toarray())
    return cosine_sim[0][1]
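
# Example: despite its name, preprocess_sentences is a similarity measure
# returning a score in [0, 1]:
#
#   preprocess_sentences("the cat sat", "the cat sat")  # -> 1.0
#   preprocess_sentences("the cat sat", "dogs bark")    # -> 0.0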


def remove_trailing_special_chars(text):
    """Strip non-alphanumeric characters and underscores from the end of text."""
    return re.sub(r'[\W_]+$', '', text)


def remove_special_chars_except_spaces(text):
    """Drop all punctuation, keeping word characters and whitespace."""
    return re.sub(r'[^\w\s]+', '', text)


def select_top_k(claim, results, top_k):
    """Select the top_k ranked sentences, skipping exact duplicates and
    sentences that are near-verbatim copies of the claim itself."""
    dup_check = set()
    top_k_sentences_urls = []
    claim_norm = remove_special_chars_except_spaces(claim).lower()
    for result in results:
        if len(top_k_sentences_urls) >= top_k:
            break
        sentence = remove_special_chars_except_spaces(result['sentence']).lower()
        # Skip empty sentences and any sentence already kept or rejected.
        if not sentence or sentence in dup_check:
            continue
        dup_check.add(sentence)
        # Reject sentences that are near-identical to the claim, either by
        # TF-IDF similarity or by containing the claim almost verbatim.
        if preprocess_sentences(claim_norm, sentence) > 0.97:
            continue
        if claim_norm in sentence and len(claim_norm) / len(sentence) > 0.92:
            continue
        top_k_sentences_urls.append({
            'sentence': result['sentence'],
            'url': result['url'],
        })
    return top_k_sentences_urls
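
# Illustrative call (hypothetical data): repeats and near-copies of the claim
# are dropped, so only distinct evidence survives:
#
#   results = [{'sentence': 'Paris is the capital of France.', 'url': 'u1'},
#              {'sentence': 'Paris is the capital of France.', 'url': 'u2'},
#              {'sentence': 'France designated Paris as its capital.', 'url': 'u3'}]
#   select_top_k('Paris is the capital of France', results, top_k=2)
#   # -> keeps only the u3 sentence; the first two match the claim too closely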


# def format_time(seconds):
#     """Format time duration nicely."""
#     return str(timedelta(seconds=round(seconds)))


def compute_embeddings_batched(model, texts, batch_size=8):
    """Compute embeddings in smaller batches to manage memory.

    NOTE: this SentenceTransformer-based variant is not called anywhere in
    this script; main() uses encode_text() instead.
    """
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            emb = model.encode(batch, batch_size=len(batch), show_progress_bar=False)
        all_embeddings.append(emb)
        # Periodically release cached GPU memory.
        if i % (batch_size * 4) == 0:
            torch.cuda.empty_cache()
            gc.collect()
    return np.vstack(all_embeddings)


def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    # Load model and tokenizer
    model = AutoModel.from_pretrained(
        "Salesforce/SFR-Embedding-2_R",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/SFR-Embedding-2_R")
    # Load target examples (one JSON object per line).
    target_examples = []
    with open(args.target_data, "r", encoding="utf-8") as json_file:
        for i, line in enumerate(json_file):
            try:
                target_examples.append(json.loads(line))
            except json.JSONDecodeError:
                print(f"Skipping malformed JSON on line {i}")
    if args.end == -1:
        args.end = len(target_examples)
    files_to_process = set(range(args.start, args.end))
    total = len(files_to_process)
    task = 'Given a web search query, retrieve relevant passages that answer the query'
    with open(args.json_output, "w", encoding="utf-8") as output_json:
        done = 0
        for idx, example in enumerate(target_examples):
            if idx not in files_to_process:
                continue
            print(f"Processing claim {example['claim_id']}... Progress: {done + 1} / {total}")
            claim = example['claim']
            # Build the HyDE-style query set: the claim itself plus each
            # non-empty hypothetical fact-checking document.
            query = [get_detailed_instruct(task, claim)] + [
                get_detailed_instruct(task, le)
                for le in example['hypo_fc_docs']
                if len(le.strip()) > 0
            ]
            sentences = [sent['sentence'] for sent in example['top_5000']][:args.retrieved_top_k]
            try:
                # Average the query embeddings into a single HyDE vector.
                query_embeddings = encode_text(model, tokenizer, query, batch_size=4)
                avg_emb_q = np.mean(query_embeddings, axis=0)
                hyde_vector = avg_emb_q.reshape((1, -1))
                # Embed the candidate sentences in smaller chunks.
                sentence_embeddings = encode_text(
                    model,
                    tokenizer,
                    sentences,
                    batch_size=args.batch_size
                )
                # Compute similarities in chunks to save memory.
                chunk_size = 1000
                all_scores = []
                for i in range(0, len(sentence_embeddings), chunk_size):
                    chunk = sentence_embeddings[i:i + chunk_size]
                    chunk_scores = compute_similarity(hyde_vector, chunk)[0]
                    all_scores.extend(chunk_scores)
                scores = np.array(all_scores)
                # Rank candidates by descending similarity, then deduplicate.
                top_k_idx = np.argsort(scores)[::-1]
                results = [example['top_5000'][i] for i in top_k_idx]
                top_k_sentences_urls = select_top_k(claim, results, args.top_k)
                json_data = {
                    "claim_id": example['claim_id'],
                    "claim": claim,
                    f"top_{args.top_k}": top_k_sentences_urls
                }
                output_json.write(json.dumps(json_data, ensure_ascii=False) + "\n")
                output_json.flush()
            except RuntimeError as e:
                print(f"Error processing claim {example['claim_id']}: {e}")
                continue
            done += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--target_data", default="data_store/dev_retrieval_top_k.json")
    parser.add_argument("--retrieved_top_k", type=int, default=5000)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("-o", "--json_output", type=str, default="data_store/dev_reranking_top_k.json")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("-s", "--start", type=int, default=0)
    parser.add_argument("-e", "--end", type=int, default=-1)
    args = parser.parse_args()
    main(args)
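
# Example invocation (the script name here is hypothetical; paths are the
# argparse defaults above). The input file is JSON Lines, one object per
# claim, with at least the fields read in main(): 'claim_id', 'claim',
# 'hypo_fc_docs', and 'top_5000' as a list of {'sentence': ..., 'url': ...}:
#
#   python rerank_sentences.py \
#       --target_data data_store/dev_retrieval_top_k.json \
#       -o data_store/dev_reranking_top_k.json \
#       --batch_size 32 --top_k 10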