arcaputo3 committed
Commit: 6178e01 · Parent: 1e3401a

Optimize MPS to use bfloat16 precision


Testing shows that MPS can handle bfloat16 precision without autocast.
This reduces memory usage while maintaining stable output.

Changes from previous commit:
- Use bfloat16 for images on both MPS and CUDA (unified dtype)
- Keep nullcontext() for MPS (no autocast - causes issues)
- CUDA path unchanged (still uses bfloat16 autocast)

Key insight: The row-wise embedding assignment fix was the critical
change. With that in place, bfloat16 works stably on MPS without
needing fp32 precision.

Tested on: macOS 26.0.1, Apple M4 Max, PyTorch 2.9.0
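
For reference, a minimal self-contained sketch of the pattern this commit converges on (the device selection, the dummy image tensor, and the pick_autocast helper are illustrative, not part of the diff):

    from contextlib import nullcontext

    import torch

    def pick_autocast(device: torch.device):
        # Mirrors the commit's logic: no autocast on MPS (tensors are
        # already bfloat16 there); keep bfloat16 autocast on CUDA.
        if device.type == "mps":
            return nullcontext()
        return torch.autocast("cuda", dtype=torch.bfloat16)

    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")

    # Unified dtype: image tensors are bfloat16 on both MPS and CUDA.
    image = torch.rand(1, 3, 1024, 1024).to(device, torch.bfloat16)

    with pick_autocast(device), torch.no_grad():
        ...  # model forward / generate() call goes here

Note that on MPS the model weights must already be in bfloat16 (e.g. via model.to(torch.bfloat16)), since without autocast there is no per-op downcasting.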

Files changed (1)
modeling_deepseekocr.py +6 -6
modeling_deepseekocr.py

@@ -816,8 +816,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):



-        # MPS needs fp32, CUDA can use bfloat16
-        image_dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
+        # MPS and CUDA both use bfloat16
+        image_dtype = torch.bfloat16
         images_list.append(image_transform(global_view).to(image_dtype))

         # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
@@ -865,8 +865,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
         # else:
         global_view = ImageOps.pad(image, (image_size, image_size),
                                    color=tuple(int(x * 255) for x in image_transform.mean))
-        # MPS needs fp32, CUDA can use bfloat16
-        image_dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
+        # MPS and CUDA both use bfloat16
+        image_dtype = torch.bfloat16
         images_list.append(image_transform(global_view).to(image_dtype))

         if base_size == 1024:
@@ -932,7 +932,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):

         if not eval_mode:
             streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-            # MPS: no autocast (pure fp32); CUDA: keep original bfloat16 autocast
+            # MPS: no autocast (pure bfloat16); CUDA: bfloat16 autocast
             autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
             with autocast_ctx:
                 with torch.no_grad():
@@ -952,7 +952,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                 )

         else:
-            # MPS: no autocast (pure fp32); CUDA: keep original bfloat16 autocast
+            # MPS: no autocast (pure bfloat16); CUDA: bfloat16 autocast
             autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
             with autocast_ctx:
                 with torch.no_grad():
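
As a quick sanity check (not part of the commit), bfloat16 support on MPS can be confirmed directly on recent PyTorch/macOS versions (the commit was tested with PyTorch 2.9.0):

    import torch

    # Assumes an Apple Silicon machine with the MPS backend available.
    if torch.backends.mps.is_available():
        x = torch.ones(2, 2, device="mps", dtype=torch.bfloat16)
        print((x @ x).dtype)  # expected: torch.bfloat16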