import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from .prompts import format_rag_prompt
from .shared import generation_interrupt
import threading
import queue
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
}

# List of model names for easy access
model_names = list(models.keys())


# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()
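
# Illustrative sketch (not part of the app flow): the criteria simply mirrors the
# state of a threading.Event, so model.generate() halts at the next decoding step
# once the event is set from any thread. The helper below only exercises the
# criteria itself; no model is loaded.
def _interrupt_criteria_demo():
    event = threading.Event()
    criteria = InterruptCriteria(event)
    print(criteria(input_ids=None, scores=None))  # False -> generation keeps going
    event.set()
    print(criteria(input_ids=None, scores=None))  # True -> generate() would stop here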

def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the assigned models.
    """
    if generation_interrupt.is_set():
        return "", ""

    context_text = ""
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        context_text = "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    # Use a queue to get results from threads
    result_queue_a = queue.Queue()
    thread_a = threading.Thread(
        target=run_inference,
        args=(models[model_a_name], context_text, question, result_queue_a),
    )
    thread_a.start()

    summary_a = ""
    while thread_a.is_alive():
        if generation_interrupt.is_set():
            print(f"Interrupting model A ({model_a_name})...")
            # The InterruptCriteria within the thread will handle stopping generate.
            # We return early from the main control flow.
            thread_a.join(timeout=1.0)  # Give thread a moment to potentially stop
            return "", ""
        try:
            summary_a = result_queue_a.get(timeout=0.1)  # Check queue periodically
            break  # Got result
        except queue.Empty:
            continue  # Still running, check interrupt again

    # If thread finished but we didn't get a result (e.g., interrupted just before putting in queue)
    if not summary_a and not result_queue_a.empty():
        summary_a = result_queue_a.get_nowait()
    elif not summary_a and generation_interrupt.is_set():  # Check interrupt again if thread finished quickly
        return "", ""

    if generation_interrupt.is_set():  # Check between models
        return summary_a, ""

    # --- Model B ---
    result_queue_b = queue.Queue()
    thread_b = threading.Thread(
        target=run_inference,
        args=(models[model_b_name], context_text, question, result_queue_b),
    )
    thread_b.start()

    summary_b = ""
    while thread_b.is_alive():
        if generation_interrupt.is_set():
            print(f"Interrupting model B ({model_b_name})...")
            thread_b.join(timeout=1.0)
            return summary_a, ""  # Return summary_a obtained so far
        try:
            summary_b = result_queue_b.get(timeout=0.1)
            break
        except queue.Empty:
            continue

    if not summary_b and not result_queue_b.empty():
        summary_b = result_queue_b.get_nowait()
    elif not summary_b and generation_interrupt.is_set():
        return summary_a, ""

    return summary_a, summary_b

# Modified run_inference to run in a thread and use a queue for results
def run_inference(model_name, context, question, result_queue):
    """
    Run inference using the specified model. Designed to be run in a thread.
    Puts the result or an error string into the result_queue.
    """
    # Check interrupt at the very beginning of the thread
    if generation_interrupt.is_set():
        result_queue.put("")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = None
    tokenizer = None
    actual_input = None  # Pre-initialize so the finally block can always clean up
    outputs = None
    result = ""
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template
            else False  # Handle missing chat_template
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Check interrupt before loading the model
        if generation_interrupt.is_set():
            result_queue.put("")
            return

        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
        ).to(device)
        model.eval()  # Set model to evaluation mode

        text_input = format_rag_prompt(question, context, accepts_sys)

        # Check interrupt before tokenization/template application
        if generation_interrupt.is_set():
            result_queue.put("")
            return

        actual_input = tokenizer.apply_chat_template(
            text_input,
            return_tensors="pt",
            tokenize=True,
            # Consider reducing max_length if context/question is very long
            # max_length=tokenizer.model_max_length,  # Use model's max length
            # truncation=True,  # Ensure truncation if needed
            max_length=2048,  # Keep original max_length for now
            add_generation_prompt=True,
        ).to(device)

        # Ensure input does not exceed model max length after adding generation prompt.
        # This check might be redundant if the tokenizer handles it, but good for safety:
        # if actual_input.shape[1] > tokenizer.model_max_length:
        #     # Handle too long input - maybe truncate manually or raise error
        #     print(f"Warning: Input length {actual_input.shape[1]} exceeds model max length {tokenizer.model_max_length}")
        #     # Simple truncation (might lose important info):
        #     # actual_input = actual_input[:, -tokenizer.model_max_length:]

        input_length = actual_input.shape[1]
        attention_mask = torch.ones_like(actual_input).to(device)

        # Check interrupt before generation
        if generation_interrupt.is_set():
            result_queue.put("")
            return

        stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])

        with torch.inference_mode():
            outputs = model.generate(
                actual_input,
                attention_mask=attention_mask,
                max_new_tokens=512,
                pad_token_id=tokenizer.pad_token_id,
                stopping_criteria=stopping_criteria,
                do_sample=True,  # Consider adjusting sampling parameters if needed
                temperature=0.6,
                top_p=0.9,
            )

        # Check interrupt immediately after generation finishes or stops
        if generation_interrupt.is_set():
            result = ""  # Discard potentially partial result if interrupted
        else:
            # Decode the generated tokens, excluding the input tokens
            result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

        result_queue.put(result)

    except Exception as e:
        print(f"Error in inference thread for {model_name}: {e}")
        # Put error message in queue for the main thread to handle/display
        result_queue.put(f"Error generating response: {str(e)[:100]}...")

    finally:
        # Clean up resources within the thread
        del model
        del tokenizer
        del actual_input
        del outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
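
# Minimal usage sketch (assumption: the payload layout below mirrors what the rest
# of the app passes in -- a "question" plus "full_contexts" entries with a "content"
# field). Running this downloads the selected models, and because of the relative
# imports above it must be executed as part of the package
# (e.g. python -m <package>.<this_module>), not as a standalone script.
if __name__ == "__main__":
    example = {
        "question": "What does the warranty cover?",
        "full_contexts": [
            {"content": "The warranty covers manufacturing defects for two years."},
            {"content": "Accidental damage is not covered by the standard warranty."},
        ],
    }
    summary_a, summary_b = generate_summaries(example, model_names[0], model_names[1])
    print("Model A:", summary_a)
    print("Model B:", summary_b)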