fixing hardcoded cuda() for cpu inference
#21
by alexgambashidze - opened
- modeling_deepseekocr.py +16 -12
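The gist of the patch: instead of hardcoding `.cuda()` (which fails on machines without a GPU), the device is read from the model's own parameters, input tensors are moved there with `.to(device)`, and `torch.autocast` is given the matching device type. Below is a minimal, self-contained sketch of that pattern on a toy module; `TinyModel` and the tensor shapes are invented for illustration and are not part of DeepSeek-OCR.

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @torch.no_grad()
    def run(self, x):
        # Resolve placement from the weights instead of hardcoding .cuda()
        device = next(self.parameters()).device
        device_type = 'cuda' if device.type == 'cuda' else 'cpu'
        with torch.autocast(device_type, dtype=torch.bfloat16):
            return self.proj(x.to(device))

model = TinyModel()                  # stays on CPU unless explicitly moved
out = model.run(torch.randn(2, 8))   # works on CPU; also works after model.cuda()
print(out.device, out.dtype)
```

On a CUDA machine the same code runs on the GPU once the module is moved there with `model.cuda()`; nothing else changes.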
    	
modeling_deepseekocr.py CHANGED
@@ -502,7 +502,7 @@ class DeepseekOCRModel(DeepseekV2Model):
                     images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                     # exit()
 
-                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device), images_in_this_batch)
 
                 idx += 1
 
@@ -703,6 +703,10 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
     def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
         self.disable_torch_init()
 
+        # Get the device from model
+        device = next(self.parameters()).device
+        device_type = 'cuda' if device.type == 'cuda' else 'cpu'
+
         os.makedirs(output_path, exist_ok=True)
         os.makedirs(f'{output_path}/images', exist_ok=True)
 
@@ -911,12 +915,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
         if not eval_mode:
             streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(device_type, dtype=torch.bfloat16):
                 with torch.no_grad():
                     output_ids = self.generate(
-                        input_ids.unsqueeze(0).cuda(),
-                        images=[(images_crop.cuda(), images_ori.cuda())],
-                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                        input_ids.unsqueeze(0).to(device),
+                        images=[(images_crop.to(device), images_ori.to(device))],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).to(device),
                         images_spatial_crop = images_spatial_crop,
                         # do_sample=False,
                         # num_beams = 1,
@@ -929,12 +933,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     )
 
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(device_type, dtype=torch.bfloat16):
                 with torch.no_grad():
                     output_ids = self.generate(
-                        input_ids.unsqueeze(0).cuda(),
-                        images=[(images_crop.cuda(), images_ori.cuda())],
-                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                        input_ids.unsqueeze(0).to(device),
+                        images=[(images_crop.to(device), images_ori.to(device))],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).to(device),
                         images_spatial_crop = images_spatial_crop,
                         # do_sample=False,
                         # num_beams = 1,
@@ -947,7 +951,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
         if '<image>' in conversation[0]['content'] and eval_mode:
-                outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+                outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(device).shape[1]:])
                 stop_str = '<|end▁of▁sentence|>'
                 if outputs.endswith(stop_str):
                     outputs = outputs[:-len(stop_str)]
@@ -957,7 +961,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                 return outputs
 
         if '<image>' in conversation[0]['content'] and test_compress:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(device).shape[1]:])
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('='*50)
            print('image size: ', (w, h))
@@ -968,7 +972,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
        if '<image>' in conversation[0]['content'] and save_results:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(device).shape[1]:])
            stop_str = '<|end▁of▁sentence|>'
 
            print('='*15 + 'save results:' + '='*15)
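With the patch applied, leaving the checkpoint on CPU should be enough for CPU inference. A hedged usage sketch follows: the Hub repo id, prompt string, and file paths are assumptions, not taken from this PR, while the keyword arguments mirror the `infer()` signature shown in the diff above.

```python
from transformers import AutoModel, AutoTokenizer

model_id = 'deepseek-ai/DeepSeek-OCR'          # assumed Hub repo id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval()
# Deliberately no model.cuda() here: with this PR, infer() resolves the CPU device itself.

result = model.infer(
    tokenizer,
    prompt='<image>\nFree OCR.',               # placeholder prompt
    image_file='page.png',                     # placeholder path
    output_path='./ocr_out',
    base_size=1024,
    image_size=640,
    crop_mode=True,
    eval_mode=True,                            # eval_mode returns the decoded text
)
print(result)
```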
