Optimize MPS to use bfloat16 precision
After testing, MPS can handle bfloat16 precision without autocast.
This reduces memory usage while maintaining stable output.
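A quick way to sanity-check that claim (a minimal sketch; assumes an Apple Silicon machine and a recent PyTorch build):

```python
import torch

# Minimal check: a bfloat16 matmul runs natively on MPS,
# with no autocast context needed (falls back to CPU elsewhere).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
a = torch.randn(64, 64, device=device, dtype=torch.bfloat16)
b = torch.randn(64, 64, device=device, dtype=torch.bfloat16)
out = a @ b
print(out.dtype, out.device)  # torch.bfloat16 mps:0 on Apple Silicon
```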
Changes from previous commit:
- Use bfloat16 for images on both MPS and CUDA (unified dtype)
- Keep nullcontext() for MPS (no autocast; autocast causes issues there)
- CUDA path unchanged (still uses bfloat16 autocast); a sketch of the combined pattern follows this list
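The sketch below mirrors the diff further down; `image_dtype` and `autocast_ctx` are the names used in the patch, while the surrounding scaffolding is a stand-in:

```python
import torch
from contextlib import nullcontext

# Unified dtype: images are cast to bfloat16 on both MPS and CUDA.
image_dtype = torch.bfloat16

# Autocast only on CUDA; MPS runs in plain bfloat16 via nullcontext().
device_type = "mps" if torch.backends.mps.is_available() else "cuda"
autocast_ctx = (
    nullcontext()
    if device_type == "mps"
    else torch.autocast("cuda", dtype=torch.bfloat16)
)

with autocast_ctx:
    with torch.no_grad():
        pass  # generation would run here, e.g. model.generate(...)
```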
Key insight: The row-wise embedding assignment fix was the critical
change. With that in place, bfloat16 works stably on MPS without
needing fp32 precision.
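The embedding fix itself lives in the previous commit; purely to illustrate what "row-wise assignment" means here (all names and shapes below are hypothetical, not the repo's actual code):

```python
import torch

# Hypothetical shapes: (batch, seq, hidden) text embeddings and a flat
# stack of image-token features, one row per image-token position.
inputs_embeds = torch.zeros(2, 8, 16)
image_mask = torch.zeros(2, 8, dtype=torch.bool)
image_mask[0, 2:5] = True
image_mask[1, 1:4] = True
image_features = torch.randn(6, 16)

# Row-wise assignment: write each batch row's image features separately
# instead of one flat boolean-masked write across the whole batch,
# sidestepping masked-scatter problems seen on MPS.
offset = 0
for row in range(inputs_embeds.size(0)):
    n = int(image_mask[row].sum())
    inputs_embeds[row, image_mask[row]] = image_features[offset:offset + n]
    offset += n
```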
Tested on: macOS 26.0.1, Apple M4 Max, PyTorch 2.9.0
modeling_deepseekocr.py (+6, -6):
```diff
@@ -816,8 +816,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-                # MPS
-                image_dtype = torch.
+                # MPS and CUDA both use bfloat16
+                image_dtype = torch.bfloat16
                 images_list.append(image_transform(global_view).to(image_dtype))
 
                 # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
@@ -865,8 +865,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                 # else:
                 global_view = ImageOps.pad(image, (image_size, image_size),
                                            color=tuple(int(x * 255) for x in image_transform.mean))
-                # MPS
-                image_dtype = torch.
+                # MPS and CUDA both use bfloat16
+                image_dtype = torch.bfloat16
                 images_list.append(image_transform(global_view).to(image_dtype))
 
                 if base_size == 1024:
@@ -932,7 +932,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
         if not eval_mode:
             streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-            # MPS: no autocast (pure
+            # MPS: no autocast (pure bfloat16); CUDA: bfloat16 autocast
             autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
             with autocast_ctx:
                 with torch.no_grad():
@@ -952,7 +952,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     )
 
         else:
-            # MPS: no autocast (pure
+            # MPS: no autocast (pure bfloat16); CUDA: bfloat16 autocast
            autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
             with autocast_ctx:
                 with torch.no_grad():
```
