	Roll back interruption changes
utils/models.py  (+13 -69)
--- a/utils/models.py
+++ b/utils/models.py
@@ -1,5 +1,5 @@
 import os
-#
+# Keep Dynamo error suppression
 import torch._dynamo
 torch._dynamo.config.suppress_errors = True
 
@@ -17,7 +17,8 @@ from transformers import (
     BitNetForCausalLM
 )
 from .prompts import format_rag_prompt
-from .shared import generation_interrupt
+# Remove interrupt import
+# from .shared import generation_interrupt
 
 models = {
     "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
@@ -47,13 +48,13 @@ tokenizer_cache = {}
 model_names = list(models.keys())
 
 
-#
-class InterruptCriteria(StoppingCriteria):
-    def __init__(self, interrupt_event):
-        self.interrupt_event = interrupt_event
-
-    def __call__(self, input_ids, scores, **kwargs):
-        return self.interrupt_event.is_set()
+# Remove interrupt criteria class since we're not using it
+# class InterruptCriteria(StoppingCriteria):
+#     def __init__(self, interrupt_event):
+#         self.interrupt_event = interrupt_event
+#
+#     def __call__(self, input_ids, scores, **kwargs):
+#         return self.interrupt_event.is_set()
 
 
 @spaces.GPU
@@ -61,20 +62,12 @@ def generate_summaries(example, model_a_name, model_b_name):
     """
     Generates summaries for the given example using the assigned models sequentially.
     """
-    if generation_interrupt.is_set():
-        print("Generation interrupted before starting")
-        return "", ""
-
+    # Remove interrupt checks
     context_text = ""
     context_parts = []
 
     if "full_contexts" in example and example["full_contexts"]:
         for i, ctx in enumerate(example["full_contexts"]):
-            # Check interrupt during context processing
-            if generation_interrupt.is_set():
-                print("Generation interrupted during context processing")
-                return "", ""
-
             content = ""
 
             # Extract content from either dict or string
@@ -97,18 +90,10 @@ def generate_summaries(example, model_a_name, model_b_name):
 
     question = example.get("question", "")
 
-    if generation_interrupt.is_set():
-        print("Generation interrupted before model A")
-        return "", ""
-
     print(f"Starting inference for Model A: {model_a_name}")
     # Run model A
     summary_a = run_inference(models[model_a_name], context_text, question)
 
-    if generation_interrupt.is_set():
-        print("Generation interrupted after model A, before model B")
-        return summary_a, ""
-
     print(f"Starting inference for Model B: {model_b_name}")
     # Run model B
     summary_b = run_inference(models[model_b_name], context_text, question)
@@ -121,13 +106,8 @@ def generate_summaries(example, model_a_name, model_b_name):
 def run_inference(model_name, context, question):
     """
     Run inference using the specified model.
-    Returns the generated text
+    Returns the generated text.
     """
-    # Check interrupt at the beginning
-    if generation_interrupt.is_set():
-        print(f"Inference interrupted before starting for {model_name}")
-        return ""
-
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     result = ""
     tokenizer_kwargs = {
@@ -146,11 +126,6 @@ def run_inference(model_name, context, question):
         if model_name in tokenizer_cache:
             tokenizer = tokenizer_cache[model_name]
         else:
-            # Check interrupt before loading tokenizer
-            if generation_interrupt.is_set():
-                print(f"Inference interrupted before loading tokenizer for {model_name}")
-                return ""
-
             # Common arguments for tokenizer loading
             tokenizer_load_args = {"padding_side": "left", "token": True}
 
@@ -170,21 +145,8 @@ def run_inference(model_name, context, question):
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        # Check interrupt before loading the model
-        if generation_interrupt.is_set():
-            print(f"Inference interrupted before loading model {model_name}")
-            return ""
-
-        # Create interrupt criteria for this generation
-        interrupt_criteria = InterruptCriteria(generation_interrupt)
-
         print("REACHED HERE BEFORE pipe")
         print(f"Loading model {model_name}...")
-
-        # Check interrupt before model loading
-        if generation_interrupt.is_set():
-            print(f"Inference interrupted during model loading for {model_name}")
-            return ""
 
         if "bitnet" in model_name.lower():
             bitnet_model = BitNetForCausalLM.from_pretrained(
@@ -226,11 +188,6 @@ def run_inference(model_name, context, question):
                 torch_dtype=torch.bfloat16,
             )
 
-        # Final interrupt check before generation
-        if generation_interrupt.is_set():
-            print(f"Inference interrupted before generation for {model_name}")
-            return ""
-
         text_input = format_rag_prompt(question, context, accepts_sys)
 
         print(f"Starting generation for {model_name}")
@@ -239,7 +196,6 @@ def run_inference(model_name, context, question):
             result = pipe(
                 text_input,
                 max_new_tokens=512,
-                stopping_criteria=[interrupt_criteria],
                 generation_kwargs={"skip_special_tokens": True}
             )[0]["generated_text"]
 
@@ -263,18 +219,12 @@ def run_inference(model_name, context, question):
             prompt_tokens_length = input_ids.shape[1]
 
             with torch.inference_mode():
-                # Check interrupt before generation
-                if generation_interrupt.is_set():
-                    print(f"Inference interrupted before torch generation for {model_name}")
-                    return ""
-
                 output_sequences = model.generate(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
                     max_new_tokens=512,
                     eos_token_id=tokenizer.eos_token_id,
-                    pad_token_id=tokenizer.pad_token_id,
-                    stopping_criteria=[interrupt_criteria]
+                    pad_token_id=tokenizer.pad_token_id
                 )
 
             generated_token_ids = output_sequences[0][prompt_tokens_length:]
@@ -288,15 +238,10 @@ def run_inference(model_name, context, question):
         #         **tokenizer_kwargs,
        #     ).to(bitnet_model.device)
         #     with torch.inference_mode():
-        #         # Check interrupt before generation
-        #         if generation_interrupt.is_set():
-        #             return ""
         #         output_sequences = bitnet_model.generate(
         #             **formatted,
         #             max_new_tokens=512,
-        #             stopping_criteria=[interrupt_criteria]
         #         )
-
         #         result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
         else:  # For other models
             formatted = pipe.tokenizer.apply_chat_template(
@@ -310,7 +255,6 @@ def run_inference(model_name, context, question):
             outputs = pipe(
                 formatted,
                 max_new_tokens=512,
-                stopping_criteria=[interrupt_criteria],
                 generation_kwargs={"skip_special_tokens": True}
             )
             result = outputs[0]["generated_text"][input_length:]
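For context, the mechanism this commit rolls back is the standard transformers cooperative-interrupt pattern: a threading.Event shared with the UI is polled at coarse checkpoints and, through a StoppingCriteria subclass passed to generate()/pipeline(), once per generated token. After the rollback, generation always runs to completion (up to max_new_tokens). Below is a minimal, self-contained sketch of that removed pattern; the helper name generate_with_interrupt and the direct use of AutoTokenizer/AutoModelForCausalLM (instead of the Space's cached pipelines) are illustrative assumptions, not the Space's exact code.

import threading

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Event a UI callback (e.g. a "stop" button) would set to request cancellation.
generation_interrupt = threading.Event()


class InterruptCriteria(StoppingCriteria):
    """Stop generation as soon as the shared event is set; generate() invokes this after each new token."""

    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()


def generate_with_interrupt(model_id, prompt, max_new_tokens=512):
    # Cheap early exit before any expensive loading, mirroring the removed checks.
    if generation_interrupt.is_set():
        return ""

    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            # Token-level cancellation: checked between decoding steps.
            stopping_criteria=StoppingCriteriaList([InterruptCriteria(generation_interrupt)]),
        )
    # Decode only the newly generated tokens, skipping the prompt.
    return tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

A caller would run generate_with_interrupt() in a worker thread and call generation_interrupt.set() from the UI thread; generation then stops at the next token boundary and the partial text generated so far is returned.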
 
			
