Commit c1e53ae
Parent(s): acd39ac

Add support for XPU to run gpt-oss

Files changed:
- torch-ext/triton_kernels/matmul_ogs.py +2 -1
- torch-ext/triton_kernels/matmul_ogs_details/_common.py +13 -1
- torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py +6 -5
- torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +3 -3
- torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py +80 -1
- torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py +41 -0
- torch-ext/triton_kernels/numerics_details/flexpoint.py +2 -1
- torch-ext/triton_kernels/target_info.py +47 -26
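The new dispatch in opt_flags.py (below) keys off the backend string reported by the active Triton target. A minimal check that the Intel XPU backend is what Triton sees, assuming a Triton build with an active driver (on other machines this prints "cuda" or "hip"):

import triton

# Same call chain the dispatch in opt_flags.py uses; with this commit,
# "xpu" is handled alongside "cuda" and "hip".
target = triton.runtime.driver.active.get_current_target()
print(target.backend)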
torch-ext/triton_kernels/matmul_ogs.py
CHANGED

@@ -602,6 +602,7 @@ def matmul_ogs_torch(x, w, bias,
                      betas = None,
                      gammas = None,
                      round_x = None, round_y = None,
+                     device: str = "cuda",
                      ):
    is_input_batched = x.ndim == 3
    assert x.dtype.itemsize > 1
@@ -641,7 +642,7 @@ def matmul_ogs_torch(x, w, bias,
        else:
            idx = gather_indx.src_indx[lo:hi] // n_expts_act
        batch = i if is_input_batched else 0
-        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                           w[i].float())
        if bias is not None:
            out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
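The only functional change here is threading a `device` argument into the reference path: `torch.arange` used to pin its index tensor to "cuda", which fails on machines without a CUDA device. A minimal sketch of the pattern; the helper name is illustrative rather than part of the module, and device="xpu" assumes an XPU build of PyTorch:

import torch

def make_row_ids(lo: int, hi: int, device: str = "cpu") -> torch.Tensor:
    # previously hard-coded to device="cuda"; now follows the caller's device
    return torch.arange(lo, hi, device=device)

print(make_row_ids(0, 4))            # tensor([0, 1, 2, 3])
# make_row_ids(0, 4, device="xpu")   # on an Intel GPU build of PyTorch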
torch-ext/triton_kernels/matmul_ogs_details/_common.py
CHANGED

@@ -7,9 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor
 # -----------------------------------------------------------------------------
 # Utilities
 # -----------------------------------------------------------------------------
+try:
+    _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+    _parts = _ver_str.split(".")
+    _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+except Exception:
+    _ver_tuple = (0, 0, 0)
 
+if _ver_tuple > (3, 4, 0) and hasattr(triton, "constexpr_function"):
+    _constexpr_function = triton.constexpr_function
+else:
+    _constexpr_function = tl.constexpr_function
 
-@…
+
+
+@_constexpr_function
 def get_scaled_dot_format_string(dtype: tl.dtype):
     mapping = {
         tl.float16: "fp16",
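The shim above strips local build suffixes (e.g. "+git…") from triton.__version__ and only uses triton.constexpr_function when the version is newer than 3.4.0 and the attribute exists; otherwise it falls back to tl.constexpr_function. A standalone sketch of that version gate, with made-up version strings so it runs without Triton installed:

def pick_decorator(version: str) -> str:
    # parse "X.Y.Z[+local]" into a comparable tuple, padding missing fields
    parts = version.split("+")[0].split(".")
    ver = tuple(int(p) for p in (parts + ["0", "0", "0"])[:3])
    return "triton.constexpr_function" if ver > (3, 4, 0) else "tl.constexpr_function"

assert pick_decorator("3.4.0") == "tl.constexpr_function"
assert pick_decorator("3.5.0+gitabc123") == "triton.constexpr_function"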
torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py
CHANGED

@@ -4,25 +4,26 @@ from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ..target_info import cuda_capability_geq as _cuda_capability_geq
 from ..target_info import is_hip as _is_hip
+from ._common import _constexpr_function
 
 
 # fmt: off
-@…
+@_constexpr_function
 def is_hip():
     return _is_hip()
 
 
-@…
+@_constexpr_function
 def cuda_capability_geq(x, y):
     return _cuda_capability_geq(x, y)
 
 
-@…
+@_constexpr_function
 def log2(n):
     return len(bin(n)) - 3
 
 
-@…
+@_constexpr_function
 def _permute_to_end_order(n: int, axis: int):
     """
     Returns the order of the axes of a tensor to permute `axis` to the end.

@@ -105,7 +106,7 @@ def _finalize_matmul_launch_metadata(grid, kernel, args):
     return ret
 
 
-@…
+@_constexpr_function
 def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
     """
     Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
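Aside from routing the decorators through the new shim, the helpers themselves are unchanged; `log2` in particular leans on Python's bin() formatting (len("0b1000") - 3 == 3). A quick plain-Python check of that trick, which behaves as floor(log2 n) for positive integers:

def log2(n: int) -> int:
    # drop the "0b" prefix and the leading digit: the rest counts the exponent
    return len(bin(n)) - 3

assert [log2(n) for n in (1, 2, 8, 1024)] == [0, 1, 3, 10]
assert log2(6) == 2  # non power of two: floor(log2(6))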
torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
CHANGED

@@ -12,14 +12,14 @@ from ..numerics_details.flexpoint import (
     compute_scale,
 )
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string, _constexpr_function
 
 
-@…
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
-@…
+@_constexpr_function
 def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
     if isinstance(tensor_or_desc, tl.tensor):
         return tensor_or_desc.dtype.element_ty
torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py
CHANGED

@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import triton
 from ..target_info import get_cdna_version
 import torch
-from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel
 
 
 @dataclass
@@ -30,6 +30,83 @@ class OptFlags:
         raise ValueError("Not supported")
 
 
+def make_default_opt_flags_intel(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    elif routing_data.expected_tokens_per_expt is None:
+        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+    # pid swizzling
+    group_m = 8
+    xcd_swizzle = 1
+    # block_m
+    if constraints.get("block_m", None):
+        block_m = constraints["block_m"]
+    elif enforce_bitwise_invariance:
+        block_m = 128
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    # block n
+    block_n = opt_flags_intel.compute_block_n(n)
+    # is_persistent
+    is_persistent = constraints.get("is_persistent", False)
+    # block k
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
+    else:
+        block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+    # split_k
+    if constraints.get("split_k", None) is not None:
+        split_k = constraints["split_k"]
+    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+        split_k = 1
+    else:
+        estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+        split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+    epilogue_subtile = constraints.get('epilogue_subtile', None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+        num_stages=constraints.get("num_stages", 2),
+        fused_scatter=constraints.get('fused_scatter', False),
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=None,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=dict(),
+        idle_sms=0,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+
 
 def make_default_opt_flags_amd(
     out_dtype,
@@ -292,6 +369,8 @@ def make_opt_flags(
         enforce_bitwise_invariance, epilogue_effective_itemsize,
         _opt_flags_constraints]
    backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "xpu":
+        return make_default_opt_flags_intel(*args)
    if backend == "hip":
        return make_default_opt_flags_amd(*args)
    if backend == "cuda":
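make_default_opt_flags_intel sizes block_m from the expected tokens per expert, clamping to a power of two in [16, 128] unless bitwise invariance pins it at 128. A standalone restatement of just that heuristic (assumes a local Triton install for triton.next_power_of_2; the function name is illustrative):

import triton

def intel_block_m(tokens_per_expt: int, enforce_bitwise_invariance: bool = False) -> int:
    if enforce_bitwise_invariance:
        return 128
    # round up to a power of two, then clamp to [16, 128]
    return max(16, min(triton.next_power_of_2(tokens_per_expt), 128))

print(intel_block_m(1))    # 16
print(intel_block_m(40))   # 64
print(intel_block_m(300))  # 128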
torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py
ADDED

@@ -0,0 +1,41 @@
+import torch
+import triton
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    if routing_data is not None:
+        grid_m = routing_data.n_blocks(m, block_m)
+    else:
+        grid_m = triton.cdiv(m, block_m)
+    grid_n = (n + block_n - 1) // block_n
+    return grid_m * grid_n
+
+
+def compute_block_n(n: int):
+    # block_n:
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(k: int | None, is_persistent: bool, precision_config):
+    if k is not None:
+        block_k = max(32, min(128, triton.next_power_of_2(k)))
+    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+    if is_persistent and has_mx_weight_scale:
+        block_k = min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+    device_props = torch.xpu.get_device_properties(0)
+    n_sms = device_props.gpu_subslice_count
+    split_k = n_sms // grid_size
+    if k is not None:
+        # avoid split_k for small k
+        num_block_k = triton.cdiv(k, block_k)
+        split_k = min(split_k, num_block_k // 4)
+    split_k = max(split_k, 1)
+    return split_k
+
+
+def compute_num_warps(block_m, block_n):
+    return max(block_m * block_n // 4096, 4)
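compute_split_k starts from how many subslices the output grid leaves idle (n_sms // grid_size) and then caps the split at a quarter of the K tiles so small reductions are not split. A sketch of the same arithmetic with the torch.xpu device query replaced by an explicit n_sms argument, so it can be exercised without an Intel GPU (helper name is illustrative):

import triton

def split_k_for(block_k: int, k: int | None, grid_size: int, n_sms: int) -> int:
    split_k = n_sms // grid_size              # idle subslices per output tile
    if k is not None:
        num_block_k = triton.cdiv(k, block_k)
        split_k = min(split_k, num_block_k // 4)  # avoid split_k for small k
    return max(split_k, 1)

# e.g. 64 subslices, a 4-tile grid, k=4096, block_k=128 -> 8-way split
print(split_k_for(128, 4096, 4, 64))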
torch-ext/triton_kernels/numerics_details/flexpoint.py
CHANGED

@@ -1,5 +1,6 @@
 from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
 from .. import target_info
+from ..matmul_ogs_details._common import _constexpr_function
 import triton
 import triton.language as tl
 
@@ -52,7 +53,7 @@ def rcp_max_finite(dtype):
     tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
 
 
-@…
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
torch-ext/triton_kernels/target_info.py
CHANGED

@@ -1,54 +1,70 @@
 import torch
 import triton
 
-…
+from .matmul_ogs_details._common import _constexpr_function
+from triton.runtime import driver
 
+
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
+
+
+current_target.__triton_builtin__ = True
+
+
+@_constexpr_function
 def is_cuda():
-…
-…
-        cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
-    return cached_capabilities["is_cuda"]
+    target = current_target()
+    return target is not None and target.backend == "cuda"
 
 
+@_constexpr_function
 def is_hip():
-…
-…
-…
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@_constexpr_function
+def is_xpu():
+    target = current_target()
+    return target is not None and target.backend == "xpu"
 
 
+@_constexpr_function
 def is_hip_cdna3():
-…
-…
-        cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx942')
-    return cached_capabilities["is_hip_cdna3"]
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
 
 
+@_constexpr_function
 def is_hip_cdna4():
-…
-…
-        cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx950')
-    return cached_capabilities["is_hip_cdna4"]
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
 
 
+@_constexpr_function
 def cuda_capability_geq(major, minor=0):
     """
     Determines whether we have compute capability >= (major, minor) and
     returns this as a constexpr boolean. This can be used for guarding
     inline asm implementations that require a certain compute capability.
     """
-…
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
         return False
-…
-…
-            cached_capabilities["cuda"] = torch.cuda.get_device_capability()
-        else:
-            cached_capabilities["cuda"] = (0, 0)
-    return cached_capabilities["cuda"] >= (major, minor)
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
 
 
 def get_cdna_version():
     """
     Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently

@@ -65,13 +81,18 @@ def get_cdna_version():
     return -1
 
 
+@_constexpr_function
 def has_tma_gather():
     return cuda_capability_geq(10, 0)
 
 
+@_constexpr_function
 def has_native_mxfp():
     return cuda_capability_geq(10, 0)
 
 
 def num_sms():
-…
+    if is_cuda():
+        return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_xpu():
+        return torch.xpu.get_device_properties(0).max_compute_units