Update app.py
app.py CHANGED
@@ -17,13 +17,13 @@ def load_model_and_processor(hf_token: str):
         return _model_cache[hf_token]
     device = torch.device("cpu")
     model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/maira-2",
-        trust_remote_code=True,
+        "microsoft/maira-2",
+        trust_remote_code=True,
         use_auth_token=hf_token
     )
     processor = AutoProcessor.from_pretrained(
-        "microsoft/maira-2",
-        trust_remote_code=True,
+        "microsoft/maira-2",
+        trust_remote_code=True,
         use_auth_token=hf_token
     )
     model.eval()
@@ -33,7 +33,7 @@ def load_model_and_processor(hf_token: str):
 
 def get_sample_data() -> dict:
     """
-
+    Downloads sample chest X-ray images and associated data.
     """
     frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
     lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
@@ -86,7 +86,14 @@ def generate_report(hf_token, frontal, lateral, indication, technique, compariso
         return_tensors="pt",
         get_grounding=use_grounding,
     )
+    # Move all tensors to the CPU
     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+    # Remove keys containing "image_sizes" to prevent unexpected keyword errors.
+    processed_inputs = dict(processed_inputs)
+    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
+    for key in keys_to_remove:
+        processed_inputs.pop(key, None)
+
     max_tokens = 450 if use_grounding else 300
     with torch.no_grad():
         output_decoding = model.generate(
@@ -121,6 +128,12 @@ def run_phrase_grounding(hf_token, frontal, phrase):
         return_tensors="pt",
     )
     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+    # Remove keys containing "image_sizes" to prevent unexpected keyword errors.
+    processed_inputs = dict(processed_inputs)
+    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
+    for key in keys_to_remove:
+        processed_inputs.pop(key, None)
+
     with torch.no_grad():
         output_decoding = model.generate(
             **processed_inputs,
@@ -132,6 +145,7 @@ def run_phrase_grounding(hf_token, frontal, phrase):
     prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
     return prediction
 
+
 def login_ui(hf_token):
     """Authenticate the user by loading the model."""
     try:
@@ -177,14 +191,14 @@ def load_sample_findings():
     sample = get_sample_data()
     return [
         save_temp_image(sample["frontal"]),  # frontal image file path
-        save_temp_image(sample["lateral"]),
+        save_temp_image(sample["lateral"]),    # lateral image file path
         sample["indication"],
         sample["technique"],
         sample["comparison"],
         None,  # prior frontal (not used)
         None,  # prior lateral (not used)
         None,  # prior report (not used)
-        False
+        False  # grounding checkbox default
     ]
 
 def load_sample_phrase():
@@ -276,4 +290,4 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
                 outputs=pg_output
             )
 
-demo.launch()
+demo.launch()
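
Both `image_sizes` hunks apply the same defensive pattern: copy the processor output into a plain dict and drop any key containing "image_sizes" before calling `model.generate`, since a remote-code model can reject such keys as unexpected keyword arguments. A minimal standalone sketch of that pattern; the helper name and the dummy tensors are illustrative, not part of the commit:

import torch

def drop_image_size_keys(processed_inputs: dict) -> dict:
    # Hypothetical helper: return a plain-dict copy of the processor output
    # without any key containing "image_sizes", mirroring the commit's loop.
    return {k: v for k, v in processed_inputs.items() if "image_sizes" not in k}

# Illustrative stand-in for a processor's return value:
processed_inputs = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
    "image_sizes": torch.tensor([[512, 512]]),
}
filtered = drop_image_size_keys(processed_inputs)
assert set(filtered) == {"input_ids", "attention_mask"}

A dict comprehension replaces the commit's explicit keys_to_remove loop but has the same effect.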
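For context around the first hunk, `return _model_cache[hf_token]` on line 17 implies the loading code sits behind a per-token cache. A sketch of that surrounding structure, assuming a module-level dict keyed by the HF token and a cached (model, processor) tuple; only the lines visible in the diff are confirmed:

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

_model_cache = {}  # assumed: module-level cache keyed by the user's HF token

def load_model_and_processor(hf_token: str):
    # Reuse a previously loaded model/processor pair for this token.
    if hf_token in _model_cache:
        return _model_cache[hf_token]
    device = torch.device("cpu")
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        use_auth_token=hf_token,  # newer transformers releases prefer token=
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        use_auth_token=hf_token,
    )
    model.eval()
    model.to(device)  # assumed: the diff only shows device = torch.device("cpu")
    _model_cache[hf_token] = (model, processor)
    return _model_cache[hf_token]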