from vllm import LLM, SamplingParams
import json
import torch
import time
from datetime import datetime, timedelta
import argparse
from tqdm import tqdm
from typing import List, Dict, Any
import concurrent.futures
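# Generates hypothetical fact-checking passages (HyDE-style) for every claim in
# the input JSON with a vLLM-hosted model, ranks the n sampled passages per claim
# by total log-probability, and stores them under each example's 'hypo_fc_docs' key.
#
# Illustrative invocation (the script filename is a placeholder; the paths and
# model name are just the argparse defaults defined at the bottom of this file):
#   python hyde_fc_generation.py \
#       --target_data data_store/averitec/dev.json \
#       --json_output data_store/hyde_fc.json \
#       --model meta-llama/Llama-3.1-8B-Instruct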
class VLLMGenerator:
    def __init__(self, model_name: str, n: int = 8, max_tokens: int = 512,
                 temperature: float = 0.7, top_p: float = 1.0,
                 frequency_penalty: float = 0.0, presence_penalty: float = 0.0,
                 stop: List[str] = ['\n\n\n'], batch_size: int = 32):
        self.device_count = torch.cuda.device_count()
        print(f"Initializing with {self.device_count} GPUs")

        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=self.device_count,
            max_model_len=4096,
            gpu_memory_utilization=0.95,
            enforce_eager=True,
            trust_remote_code=True,
            # quantization="bitsandbytes",
            # dtype="half",
            # load_format="bitsandbytes",
            max_num_batched_tokens=4096,
            max_num_seqs=batch_size
        )

        self.sampling_params = SamplingParams(
            n=n,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stop=stop,
            logprobs=1
        )

        self.batch_size = batch_size
        self.tokenizer = self.llm.get_tokenizer()
        print(f"Initialization complete. Batch size: {batch_size}")
    def parse_response(self, responses):
        all_outputs = []
        for response in responses:
            to_return = []
            for output in response.outputs:
                text = output.text.strip()
                try:
                    logprob = sum(logprob_obj.logprob for item in output.logprobs for logprob_obj in item.values())
                except (TypeError, AttributeError):
                    logprob = 0  # Fallback if logprobs aren't available
                to_return.append((text, logprob))
            texts = [r[0] for r in sorted(to_return, key=lambda tup: tup[1], reverse=True)]
            all_outputs.append(texts)
        return all_outputs
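    # prepare_prompt builds the HyDE-style instruction for one claim. Base models
    # (e.g. OLMo checkpoints) receive the raw instruction, while chat models are
    # wrapped with the tokenizer's chat template plus a Llama-3-style assistant
    # header so generation starts directly after "Passage: ".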
    def prepare_prompt(self, claim: str, model_name: str) -> str:
        base_prompt = f"Please write a fact-checking article passage to support, refute, indicate not enough evidence, or present conflicting evidence regarding the claim.\nClaim: {claim}"
        if "OLMo" in model_name:
            return base_prompt
        else:
            messages = [{"role": "user", "content": base_prompt}]
            return self.tokenizer.apply_chat_template(messages, tokenize=False) + "<|start_header_id|>assistant<|end_header_id|>\n\nPassage: "
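    # process_batch generates passages for one batch of examples, attaches the
    # ranked texts under 'hypo_fc_docs', and returns the batch together with the
    # wall-clock time it took; on failure the error is logged and the batch is
    # returned without generated passages.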
    def process_batch(self, batch: List[Dict[str, Any]], model_name: str) -> tuple[List[Dict[str, Any]], float]:
        start_time = time.time()
        prompts = [self.prepare_prompt(example["claim"], model_name) for example in batch]
        try:
            results = self.llm.generate(prompts, sampling_params=self.sampling_params)
            outputs = self.parse_response(results)
            for example, output in zip(batch, outputs):
                example['hypo_fc_docs'] = output
            batch_time = time.time() - start_time
            return batch, batch_time
        except Exception as e:
            print(f"Error processing batch: {str(e)}")
            return batch, time.time() - start_time
# def format_time(seconds: float) -> str:
#     return str(timedelta(seconds=int(seconds)))

# def estimate_completion_time(start_time: float, processed_examples: int, total_examples: int) -> str:
#     elapsed_time = time.time() - start_time
#     examples_per_second = processed_examples / elapsed_time
#     remaining_examples = total_examples - processed_examples
#     estimated_remaining_seconds = remaining_examples / examples_per_second
#     completion_time = datetime.now() + timedelta(seconds=int(estimated_remaining_seconds))
#     return completion_time.strftime("%Y-%m-%d %H:%M:%S")
def main(args):
    total_start_time = time.time()
    print(f"Script started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Load data
    print("Loading data...")
    with open(args.target_data, 'r', encoding='utf-8') as json_file:
        examples = json.load(json_file)
    print(f"Loaded {len(examples)} examples")

    # Initialize generator
    print("Initializing generator...")
    generator = VLLMGenerator(
        model_name=args.model,
        batch_size=32
    )

    # Process data in batches
    processed_data = []
    # batch_times = []
    batches = [examples[i:i + generator.batch_size] for i in range(0, len(examples), generator.batch_size)]

    print(f"\nProcessing {len(batches)} batches...")
    with tqdm(total=len(examples), desc="Processing examples") as pbar:
        for batch_idx, batch in enumerate(batches, 1):
            processed_batch, batch_time = generator.process_batch(batch, args.model)
            processed_data.extend(processed_batch)
            # batch_times.append(batch_time)

            # Update progress and timing information
            # examples_processed = len(processed_data)
            # avg_batch_time = sum(batch_times) / len(batch_times)
            # estimated_completion = estimate_completion_time(total_start_time, examples_processed, len(examples))
            # pbar.set_postfix({
            #     'Batch': f"{batch_idx}/{len(batches)}",
            #     'Avg Batch Time': f"{avg_batch_time:.2f}s",
            #     'ETA': estimated_completion
            # })
            pbar.update(len(batch))
    # Calculate and display timing statistics
    # total_time = time.time() - total_start_time
    # avg_batch_time = sum(batch_times) / len(batch_times)
    # avg_example_time = total_time / len(examples)
    # print("\nTiming Statistics:")
    # print(f"Total Runtime: {format_time(total_time)}")
    # print(f"Average Batch Time: {avg_batch_time:.2f} seconds")
    # print(f"Average Time per Example: {avg_example_time:.2f} seconds")
    # print(f"Throughput: {len(examples)/total_time:.2f} examples/second")

    # Save results
    # print("\nSaving results...")
    with open(args.json_output, "w", encoding="utf-8") as output_json:
        json.dump(processed_data, output_json, ensure_ascii=False, indent=4)

    # print(f"Script completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    # print(f"Total runtime: {format_time(total_time)}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--target_data', default='data_store/averitec/dev.json')
    parser.add_argument('-o', '--json_output', default='data_store/hyde_fc.json')
    parser.add_argument('-m', '--model', default="meta-llama/Llama-3.1-8B-Instruct")
    args = parser.parse_args()
    main(args)