enable intel XPU platform (#7)
- add intel xpu platform support (feaa3d5aef9b856f0f68ce753342014b21ea6266)
Co-authored-by: Liu,Kaixuan <Kaixuanliu@users.noreply.huggingface.co>
- modeling_cogvlm.py +9 -3
- util.py +7 -1
- visual.py +2 -0
    	
modeling_cogvlm.py CHANGED

@@ -8,6 +8,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from torchvision import transforms
 from einops import rearrange
+import transformers
 from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.utils.logging import get_logger
 from transformers.activations import ACT2FN
@@ -723,9 +724,14 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past_key_values
-        cache_name, cache = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format
-        )
+        if transformers.__version__ >= "4.44.0":
+            cache_name, cache = self._extract_past_from_model_output(
+                outputs
+            )
+        else:
+            cache_name, cache = self._extract_past_from_model_output(
+                outputs, standardize_cache_format=standardize_cache_format
+            )
        model_kwargs[cache_name] = cache

        if getattr(outputs, "state", None) is not None:
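Note on the version gate: the new branch compares transformers.__version__ as a plain string, which can misorder versions (lexicographically, "4.5.0" >= "4.44.0" is True). A minimal sketch, not part of this change, of the same gate expressed as a semantic version comparison with the packaging library; the helper name is hypothetical:

# Hypothetical helper (not in the PR): choose the _extract_past_from_model_output
# call style based on a parsed version instead of a raw string comparison.
from packaging import version
import transformers

def use_new_extract_past_signature() -> bool:
    # The diff above calls _extract_past_from_model_output without the
    # standardize_cache_format argument on transformers >= 4.44.0.
    return version.parse(transformers.__version__) >= version.parse("4.44.0")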
    	
util.py CHANGED

@@ -7,6 +7,10 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl

+device_contexts = {
+    'cuda': torch.cuda.device,
+    'xpu': torch.xpu.device
+}

 @triton.jit
 def rotary_kernel(
@@ -197,7 +201,9 @@ def apply_rotary(

     # Need this, otherwise Triton tries to launch from cuda:0 and we get
     # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-    with torch.cuda.device(x.device.index):
+    device_type = x.device.type
+    assert device_type in device_contexts
+    with device_contexts[device_type](x.device.index):
        rotary_kernel[grid](
            output,  # data ptrs
            x,
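For context, the device_contexts mapping added above lets apply_rotary activate whichever device owns the input tensor before launching the Triton kernel, instead of hard-coding torch.cuda.device. A minimal sketch of that dispatch pattern, assuming a PyTorch build with XPU support (torch.xpu); launch_on_tensor_device and the kernel argument are hypothetical stand-ins:

import torch

# Same mapping as in util.py: device type -> context manager that
# activates a device index for kernel launches.
device_contexts = {
    'cuda': torch.cuda.device,
    'xpu': torch.xpu.device,   # requires a PyTorch build with XPU support
}

def launch_on_tensor_device(kernel, grid, x: torch.Tensor, *args):
    # 'kernel' is a hypothetical @triton.jit function standing in for rotary_kernel.
    device_type = x.device.type
    assert device_type in device_contexts, f"unsupported device type: {device_type}"
    # Enter the matching device context so Triton launches on x's device
    # rather than defaulting to cuda:0.
    with device_contexts[device_type](x.device.index):
        kernel[grid](x, *args)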
    	
visual.py CHANGED

@@ -75,6 +75,8 @@ class Attention(nn.Module):
         out = out.transpose(2, 1)
         # breakpoint()
         # output = self.dense(out.reshape(B, L, -1))
+        if not out.is_contiguous():
+            out = out.contiguous()
         output = self.dense(out.view(B, L, -1))
         output = self.output_dropout(output)
         return output
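The contiguity guard in visual.py matters because transpose returns a non-contiguous view and Tensor.view raises on such tensors, so the reshape into (B, L, -1) needs a contiguous buffer. A small illustration with made-up shapes:

import torch

B, L, H, D = 2, 4, 3, 5
out = torch.randn(B, H, L, D).transpose(2, 1)  # (B, L, H, D), non-contiguous view

assert not out.is_contiguous()
try:
    out.view(B, L, -1)  # view requires a compatible (here: contiguous) memory layout
except RuntimeError as err:
    print("view failed:", err)

# The guard added in the diff: copy into contiguous memory only when needed.
if not out.is_contiguous():
    out = out.contiguous()
flat = out.view(B, L, -1)  # (B, L, H * D)
print(flat.shape)  # torch.Size([2, 4, 15])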
