update kernels
kernels/cache_autogptq_cuda_256.cpp → cache_autogptq_cuda_256.cpp
RENAMED (file without changes)

kernels/cache_autogptq_cuda_kernel_256.cu → cache_autogptq_cuda_kernel_256.cu
RENAMED (file without changes)
kernels/cpp_kernels.py → cpp_kernels.py
RENAMED

@@ -50,6 +50,6 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
 
 extra_flags = []
 
-cache_autogptq_cuda_256_sources = ["./kernels/cache_autogptq_cuda_256.cpp",
-           "./kernels/cache_autogptq_cuda_kernel_256.cu"]
+cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp",
+           "./cache_autogptq_cuda_kernel_256.cu"]
 cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags)
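The path change matters because _cpp_extention_load_helper presumably JIT-compiles these two source files; after the rename they sit next to cpp_kernels.py instead of under kernels/, so the relative paths drop the kernels/ prefix. Below is a minimal sketch of that kind of loader, assuming it wraps torch.utils.cpp_extension.load; resolving the sources relative to __file__ is an illustration only, the committed code keeps plain "./" paths.

# Illustrative sketch only (not the repository's helper): load the renamed
# kernel sources with PyTorch's JIT C++ extension builder.
import os
from torch.utils import cpp_extension

def load_cache_kernels(extra_cuda_flags=None):
    # Resolve the sources next to this file; the committed code uses plain
    # "./" paths instead, which assume the build runs from the module root.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    sources = [
        os.path.join(module_dir, "cache_autogptq_cuda_256.cpp"),
        os.path.join(module_dir, "cache_autogptq_cuda_kernel_256.cu"),
    ]
    # cpp_extension.load() compiles on first call and caches the built module.
    return cpp_extension.load(
        name="cache_autogptq_cuda_256",
        sources=sources,
        extra_cuda_cflags=list(extra_cuda_flags or []),
        verbose=False,
    )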
    	
modeling_qwen.py
CHANGED

@@ -32,11 +32,6 @@ except ImportError:
     rearrange = None
 from torch import nn
 
-try:
-    from kernels.cpp_kernels import cache_autogptq_cuda_256
-except ImportError:
-    cache_autogptq_cuda_256 = None
-
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
@@ -294,14 +289,21 @@ class QWenAttention(nn.Module):
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
         self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
 
+        if config.use_cache_quantization and config.use_cache_kernel:
+            from .cpp_kernels import cache_autogptq_cuda_256
+            try:
+                self.cache_kernels = cache_autogptq_cuda_256
+            except ImportError:
+                self.cache_kernels = None
+
     def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
-            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+            if self.use_cache_kernel and self.cache_kernels is not None:
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
-                cache_autogptq_cuda_256.vecquant8matmul_batched_faster_old(
+                self.cache_kernels.vecquant8matmul_batched_faster_old(
                     query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
                     qk.transpose(-1, -2).contiguous(),
                     attn_weights,
@@ -353,10 +355,10 @@ class QWenAttention(nn.Module):
 
         if self.use_cache_quantization:
             qv, qv_scale, qv_zero = value
-            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+            if self.use_cache_kernel and self.cache_kernels is not None:
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
-                cache_autogptq_cuda_256.vecquant8matmul_batched_column_compression_faster_old(
+                self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
                     attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
                     qv.contiguous(),  # dtype: int32
                     attn_output,
@@ -1022,15 +1024,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if config.use_flash_attn:
             _import_flash_attn()
 
-
-        if hasattr(config, 'use_cache_quantization') and config.use_cache_quantization:
-            config.use_flash_attn = False
-            if hasattr(config, 'use_cache_kernel') and config.use_cache_kernel:
-                try:
-                    from kernels.cpp_kernels import cache_autogptq_cuda_256
-                except ImportError:
-                    cache_autogptq_cuda_256 = None
-
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
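Net effect of the modeling change: the compiled kernels are now imported lazily inside QWenAttention and stored as self.cache_kernels, so a missing or unbuilt extension only matters when config.use_cache_quantization and config.use_cache_kernel are both enabled. A hedged usage sketch follows, assuming the two flags can be passed through from_pretrained as config overrides; the checkpoint id is a placeholder, not taken from this commit.

# Illustrative usage only; the model id below is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen-7B-Chat-Int4"  # placeholder quantized checkpoint (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,        # needed so modeling_qwen.py / cpp_kernels.py are used
    use_cache_quantization=True,   # quantize cached keys/values (uint8 range per the diff)
    use_cache_kernel=True,         # route _attn through the cache_autogptq_cuda_256 kernels
).eval()

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))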

