Implement make_opt_flags function for XPU
#5 opened by YangKai0616
This view is limited to 50 files because it contains too many changes.
- build/torch-universal/triton_kernels/__init__.py +6 -0
- build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc +0 -0
- build/torch-universal/triton_kernels/_ops.py +2 -2
- build/torch-universal/triton_kernels/compaction.py +0 -0
- build/torch-universal/triton_kernels/compaction_details/_masked_compaction.py +0 -0
- build/torch-universal/triton_kernels/matmul_ogs.py +2 -1
- build/torch-universal/triton_kernels/matmul_ogs_details/_common.py +13 -1
- build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py +6 -5
- build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py +0 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +3 -3
- build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py +80 -1
- build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py +0 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py +41 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +0 -0
- build/torch-universal/triton_kernels/numerics.py +0 -0
- build/torch-universal/triton_kernels/numerics_details/__init__.py +0 -0
- build/torch-universal/triton_kernels/numerics_details/flexpoint.py +2 -1
- build/torch-universal/triton_kernels/numerics_details/mxfp.py +0 -0
- build/torch-universal/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py +0 -0
- build/torch-universal/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py +0 -0
- build/torch-universal/triton_kernels/proton_opts.py +0 -0
- build/torch-universal/triton_kernels/reduction_details/reduce_bitmatrix.py +0 -0
- build/torch-universal/triton_kernels/routing.py +0 -0
- build/torch-universal/triton_kernels/routing_details/_expt_data.py +0 -0
- build/torch-universal/triton_kernels/routing_details/_routing_compute.py +0 -0
- build/torch-universal/triton_kernels/specialize.py +0 -0
- build/torch-universal/triton_kernels/swiglu.py +1 -1
- build/torch-universal/triton_kernels/swiglu_details/_swiglu.py +0 -0
- build/torch-universal/triton_kernels/target_info.py +47 -26
- build/torch-universal/triton_kernels/tensor.py +0 -0
- build/torch-universal/triton_kernels/tensor_details/layout.py +0 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/base.py +0 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/blackwell_scale.py +0 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_scale.py +0 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_value.py +0 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/strided.py +0 -0
- build/torch-universal/triton_kernels/testing.py +0 -0
- build/torch-universal/triton_kernels/topk.py +0 -0
- build/torch-universal/triton_kernels/topk_details/__init__.py +0 -0
- build/torch-universal/triton_kernels/topk_details/_topk_backward.py +0 -0
- build/torch-universal/triton_kernels/topk_details/_topk_forward.py +0 -0
- result +1 -0
- tests/conftest.py +13 -2
- tests/test_matmul.py +13 -12
- torch-ext/triton_kernels/__init__.py +7 -0
- torch-ext/triton_kernels/matmul_ogs.py +2 -1
- torch-ext/triton_kernels/matmul_ogs_details/_common.py +13 -1
- torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py +6 -5
- torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +3 -3
- torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py +80 -1
build/torch-universal/triton_kernels/__init__.py
CHANGED
@@ -1,3 +1,9 @@
+# Make sure to add this in the build folder as this won't build if we put that here
+# docker run --rm \
+#   -v $(pwd):/app \
+#   -w /app \
+#   ghcr.io/huggingface/kernel-builder:main
+
 from . import matmul_ogs, tensor_details, numerics_details, tensor, swiglu, routing

 __all__ = ["matmul_ogs" , "tensor_details", "numerics_details", "tensor", "swiglu", "routing"]
build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc
CHANGED
Binary files a/build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc and b/build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc differ
build/torch-universal/triton_kernels/_ops.py
CHANGED
@@ -1,8 +1,8 @@
 import torch
-ops = torch.ops.
+ops = torch.ops._triton_kernels_a32f88a_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_triton_kernels_a32f88a_dirty::{op_name}"
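Note on the `_ops.py` hunk: the `_triton_kernels_a32f88a_dirty` suffix is generated at build time, so it will differ from build to build. A minimal sketch of how the prefix helper is meant to be used, assuming the built package is importable as `triton_kernels` (the op name `matmul_ogs` below is only illustrative):

    from triton_kernels._ops import add_op_namespace_prefix

    # Resolves to "<build-specific namespace>::matmul_ogs", e.g.
    # "_triton_kernels_a32f88a_dirty::matmul_ogs" for this particular build.
    qualified_name = add_op_namespace_prefix("matmul_ogs")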
build/torch-universal/triton_kernels/compaction.py
CHANGED
File without changes
build/torch-universal/triton_kernels/compaction_details/_masked_compaction.py
CHANGED
File without changes
build/torch-universal/triton_kernels/matmul_ogs.py
CHANGED
@@ -602,6 +602,7 @@ def matmul_ogs_torch(x, w, bias,
                      betas = None,
                      gammas = None,
                      round_x = None, round_y = None,
+                     device: str = "cuda",
                      ):
     is_input_batched = x.ndim == 3
     assert x.dtype.itemsize > 1
@@ -641,7 +642,7 @@ def matmul_ogs_torch(x, w, bias,
         else:
             idx = gather_indx.src_indx[lo:hi] // n_expts_act
         batch = i if is_input_batched else 0
-        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=
+        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                            w[i].float())
         if bias is not None:
             out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
build/torch-universal/triton_kernels/matmul_ogs_details/_common.py
CHANGED
@@ -7,9 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor
 # -----------------------------------------------------------------------------
 # Utilities
 # -----------------------------------------------------------------------------
+try:
+    _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+    _parts = _ver_str.split(".")
+    _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+except Exception:
+    _ver_tuple = (0, 0, 0)

+if _ver_tuple > (3, 4, 0) and hasattr(triton, "constexpr_function"):
+    _constexpr_function = triton.constexpr_function
+else:
+    _constexpr_function = tl.constexpr_function

-
+
+
+@_constexpr_function
 def get_scaled_dot_format_string(dtype: tl.dtype):
     mapping = {
         tl.float16: "fp16",
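The `_constexpr_function` shim above selects `triton.constexpr_function` when the installed Triton is newer than 3.4.0 and exposes it at the top level, and otherwise falls back to `tl.constexpr_function`; every decorated helper in the hunks below reuses it. A minimal sketch of how a downstream helper would pick it up (the `bytes_per_element` helper is made up for illustration):

    import triton.language as tl

    from triton_kernels.matmul_ogs_details._common import _constexpr_function


    @_constexpr_function
    def bytes_per_element(dtype: tl.dtype) -> int:
        # Evaluated while the kernel is specialized, so the result folds into a constexpr.
        return dtype.primitive_bitwidth // 8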
build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py
CHANGED
@@ -4,25 +4,26 @@ from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ..target_info import cuda_capability_geq as _cuda_capability_geq
 from ..target_info import is_hip as _is_hip
+from ._common import _constexpr_function


 # fmt: off
-@
+@_constexpr_function
 def is_hip():
     return _is_hip()


-@
+@_constexpr_function
 def cuda_capability_geq(x, y):
     return _cuda_capability_geq(x, y)


-@
+@_constexpr_function
 def log2(n):
     return len(bin(n)) - 3


-@
+@_constexpr_function
 def _permute_to_end_order(n: int, axis: int):
     """
     Returns the order of the axes of a tensor to permute `axis` to the end.
@@ -105,7 +106,7 @@ def _finalize_matmul_launch_metadata(grid, kernel, args):
     return ret


-@
+@_constexpr_function
 def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
     """
     Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py
CHANGED
File without changes
build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
CHANGED
@@ -12,14 +12,14 @@ from ..numerics_details.flexpoint import (
     compute_scale,
 )
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string, _constexpr_function


-@
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)

-@
+@_constexpr_function
 def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
     if isinstance(tensor_or_desc, tl.tensor):
         return tensor_or_desc.dtype.element_ty
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py
CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import triton
 from ..target_info import get_cdna_version
 import torch
-from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel


 @dataclass
@@ -30,6 +30,83 @@ class OptFlags:
             raise ValueError("Not supported")


+def make_default_opt_flags_intel(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    elif routing_data.expected_tokens_per_expt is None:
+        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+    # pid swizzling
+    group_m = 8
+    xcd_swizzle = 1
+    # block_m
+    if constraints.get("block_m", None):
+        block_m = constraints["block_m"]
+    elif enforce_bitwise_invariance:
+        block_m = 128
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    # block n
+    block_n = opt_flags_intel.compute_block_n(n)
+    # is_persistent
+    is_persistent = constraints.get("is_persistent", False)
+    # block k
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
+    else:
+        block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+    # split_k
+    if constraints.get("split_k", None) is not None:
+        split_k = constraints["split_k"]
+    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+        split_k = 1
+    else:
+        estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+        split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+    epilogue_subtile = constraints.get('epilogue_subtile', None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+        num_stages=constraints.get("num_stages", 2),
+        fused_scatter=constraints.get('fused_scatter', False),
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=None,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=dict(),
+        idle_sms=0,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+

 def make_default_opt_flags_amd(
     out_dtype,
@@ -292,6 +369,8 @@ def make_opt_flags(
             enforce_bitwise_invariance, epilogue_effective_itemsize,
             _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "xpu":
+        return make_default_opt_flags_intel(*args)
     if backend == "hip":
         return make_default_opt_flags_amd(*args)
     if backend == "cuda":
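For orientation, the dispatch key in `make_opt_flags` is the backend name reported by the active Triton driver, so the new branch only fires on Intel GPUs. A quick way to check which default path a given machine will take, assuming a working Triton installation:

    import triton

    # Same probe make_opt_flags() uses before picking a backend-specific default.
    backend = triton.runtime.driver.active.get_current_target().backend
    # "xpu" -> make_default_opt_flags_intel, "hip" -> AMD defaults, "cuda" -> NVIDIA defaults.
    print(backend)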
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
CHANGED
File without changes
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py
ADDED
@@ -0,0 +1,41 @@
+import torch
+import triton
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    if routing_data is not None:
+        grid_m = routing_data.n_blocks(m, block_m)
+    else:
+        grid_m = triton.cdiv(m, block_m)
+    grid_n = (n + block_n - 1) // block_n
+    return grid_m * grid_n
+
+
+def compute_block_n(n: int):
+    # block_n:
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(k: int | None, is_persistent: bool, precision_config):
+    if k is not None:
+        block_k = max(32, min(128, triton.next_power_of_2(k)))
+    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+    if is_persistent and has_mx_weight_scale:
+        block_k = min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+    device_props = torch.xpu.get_device_properties(0)
+    n_sms = device_props.gpu_subslice_count
+    split_k = n_sms // grid_size
+    if k is not None:
+        # avoid split_k for small k
+        num_block_k = triton.cdiv(k, block_k)
+        split_k = min(split_k, num_block_k // 4)
+    split_k = max(split_k, 1)
+    return split_k
+
+
+def compute_num_warps(block_m, block_n):
+    return max(block_m * block_n // 4096, 4)
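A worked example of the new Intel heuristics with made-up shapes; the split-k value is device dependent because `compute_split_k` divides `gpu_subslice_count` (queried from `torch.xpu.get_device_properties(0)`) by the estimated grid size:

    import triton

    # Hypothetical problem shape: n=4096, k=192.
    n, k = 4096, 192
    block_n = max(16, min(128, triton.next_power_of_2(n)))  # -> 128
    block_k = max(32, min(128, triton.next_power_of_2(k)))  # next_power_of_2(192) = 256, clamped -> 128
    num_warps = max(64 * block_n // 4096, 4)                # with block_m = 64 -> max(2, 4) = 4
    # compute_split_k() then caps the split at num_block_k // 4 and never returns less than 1.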
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
CHANGED
File without changes
build/torch-universal/triton_kernels/numerics.py
CHANGED
File without changes
build/torch-universal/triton_kernels/numerics_details/__init__.py
CHANGED
File without changes
build/torch-universal/triton_kernels/numerics_details/flexpoint.py
CHANGED
@@ -1,5 +1,6 @@
 from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
 from .. import target_info
+from ..matmul_ogs_details._common import _constexpr_function
 import triton
 import triton.language as tl

@@ -52,7 +53,7 @@ def rcp_max_finite(dtype):
     tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")


-@
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)

build/torch-universal/triton_kernels/numerics_details/mxfp.py
CHANGED
File without changes
build/torch-universal/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
CHANGED
File without changes
build/torch-universal/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
CHANGED
File without changes
build/torch-universal/triton_kernels/proton_opts.py
CHANGED
File without changes
build/torch-universal/triton_kernels/reduction_details/reduce_bitmatrix.py
CHANGED
File without changes
build/torch-universal/triton_kernels/routing.py
CHANGED
File without changes
build/torch-universal/triton_kernels/routing_details/_expt_data.py
CHANGED
File without changes
build/torch-universal/triton_kernels/routing_details/_routing_compute.py
CHANGED
File without changes
build/torch-universal/triton_kernels/specialize.py
CHANGED
File without changes
build/torch-universal/triton_kernels/swiglu.py
CHANGED
@@ -35,7 +35,7 @@ class SwiGLU(torch.autograd.Function):
         # optimization hyperparameters
         BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
         num_warps = 4
-        kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
+        kwargs = {'maxnreg': 64} if not target_info.is_hip() and not target_info.is_xpu() else {}
         # launch semi-persistent kernel
         N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
         num_sms = target_info.num_sms()
build/torch-universal/triton_kernels/swiglu_details/_swiglu.py
CHANGED
File without changes
build/torch-universal/triton_kernels/target_info.py
CHANGED
@@ -1,54 +1,70 @@
 import torch
 import triton

-
+from .matmul_ogs_details._common import _constexpr_function
+from triton.runtime import driver

+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()

+current_target.__triton_builtin__ = True
+
+
+@_constexpr_function
 def is_cuda():
-
-
-    cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
-    return cached_capabilities["is_cuda"]
+    target = current_target()
+    return target is not None and target.backend == "cuda"


+@_constexpr_function
 def is_hip():
-
-
-
+    target = current_target()
+    return target is not None and target.backend == "hip"
+

+@_constexpr_function
+def is_xpu():
+    target = current_target()
+    return target is not None and target.backend == "xpu"

+
+@_constexpr_function
 def is_hip_cdna3():
-
-
-    cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
-                                           and target.arch == 'gfx942')
-    return cached_capabilities["is_hip_cdna3"]
+    target = current_target()
+    return target is not None and target.arch == "gfx942"


+@_constexpr_function
 def is_hip_cdna4():
-
-
-    cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
-                                           and target.arch == 'gfx950')
-    return cached_capabilities["is_hip_cdna4"]
+    target = current_target()
+    return target is not None and target.arch == "gfx950"


+@_constexpr_function
 def cuda_capability_geq(major, minor=0):
     """
     Determines whether we have compute capability >= (major, minor) and
     returns this as a constexpr boolean. This can be used for guarding
     inline asm implementations that require a certain compute capability.
     """
-
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
         return False
-
-
-    cached_capabilities["cuda"] = torch.cuda.get_device_capability()
-    else:
-        cached_capabilities["cuda"] = (0, 0)
-    return cached_capabilities["cuda"] >= (major, minor)
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor


+@_constexpr_function
 def get_cdna_version():
     """
     Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
@@ -65,13 +81,18 @@ def get_cdna_version():
     return -1


+@_constexpr_function
 def has_tma_gather():
     return cuda_capability_geq(10, 0)


+@_constexpr_function
 def has_native_mxfp():
     return cuda_capability_geq(10, 0)


 def num_sms():
-
+    if is_cuda():
+        return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_xpu():
+        return torch.xpu.get_device_properties(0).max_compute_units
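The `target_info` rewrite drops the cached-capability dictionary in favour of querying the active driver target on each call, adds `is_xpu()`, and teaches `num_sms()` to answer on XPU as well. A short sketch of how downstream code uses the new predicates, mirroring the `swiglu.py` change above:

    from triton_kernels import target_info

    # Register hints are CUDA-only, so skip them on both HIP and XPU.
    kwargs = {"maxnreg": 64} if not target_info.is_hip() and not target_info.is_xpu() else {}

    # num_sms() now reports multi_processor_count on CUDA and max_compute_units on XPU.
    n_units = target_info.num_sms()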
build/torch-universal/triton_kernels/tensor.py
CHANGED
File without changes
build/torch-universal/triton_kernels/tensor_details/layout.py
CHANGED
File without changes
build/torch-universal/triton_kernels/tensor_details/layout_details/base.py
CHANGED
File without changes
build/torch-universal/triton_kernels/tensor_details/layout_details/blackwell_scale.py
CHANGED
File without changes
build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_scale.py
CHANGED
File without changes
build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_value.py
CHANGED
File without changes
build/torch-universal/triton_kernels/tensor_details/layout_details/strided.py
CHANGED
File without changes
build/torch-universal/triton_kernels/testing.py
CHANGED
File without changes
build/torch-universal/triton_kernels/topk.py
CHANGED
File without changes
build/torch-universal/triton_kernels/topk_details/__init__.py
CHANGED
File without changes
build/torch-universal/triton_kernels/topk_details/_topk_backward.py
CHANGED
File without changes
build/torch-universal/triton_kernels/topk_details/_topk_forward.py
CHANGED
File without changes
result
ADDED
@@ -0,0 +1 @@
+/nix/store/jkq2iihqbwik7pdn215w2ysgzhsgj3sc-torch-ext-bundle
tests/conftest.py
CHANGED
@@ -1,5 +1,5 @@
 import pytest
-
+import triton

 def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda")
@@ -12,8 +12,19 @@

 @pytest.fixture
 def fresh_knobs(monkeypatch):
+    try:
+        _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+        _parts = _ver_str.split(".")
+        _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+    except Exception:
+        _ver_tuple = (0, 0, 0)
+
     from triton._internal_testing import _fresh_knobs_impl
-
+    if _ver_tuple > (3, 4, 0):
+        fresh_function, reset_function = _fresh_knobs_impl()
+    else:
+        fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+
     try:
         yield fresh_function()
     finally:
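Since `conftest.py` keeps the `--device` option (still defaulting to `cuda`), the suite can presumably be pointed at an Intel GPU straight from the command line, for example:

    pytest tests/test_matmul.py --device xpu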
tests/test_matmul.py
CHANGED
@@ -20,7 +20,7 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4

 # ---------------
 # initialize data
@@ -70,7 +70,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
         gs0 = None
         gs1 = None
-    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
     return x, w, bias, gs0, gs1

@@ -291,14 +291,15 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     if hbm_swizzling:
         if is_hip():
             pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
-        if
-
-
-        if
-
-
-
-
+        if is_cuda():
+            if torch.cuda.get_device_capability()[0] < 9:
+                pytest.skip("NYI. Ampere swizzling.")
+            if torch.cuda.get_device_capability()[0] < 10:
+                if "mxfloat4" not in weight_dtype_str:
+                    pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+                if k % 64 != 0 or n % 64 != 0:
+                    # Automatic padding not implemented for Hopper swizzle
+                    pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

     # launch metadata for batched / mx types may not work yet.
     test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
@@ -306,7 +307,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     torch.manual_seed(0)

     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
         # Override block_k for testing correctness. The default is temporarily 128 for
         # performance reasons which doesn't work with persistent matmul.
         # TODO: revisit when Triton is better for H100 + MXFP4
@@ -462,7 +463,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas

     round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
     ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,  #
-                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
+                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, device=device)
     scale = lambda val, scal: val if scal is None else val / scal
     if n_expt_shards > 1:
         if do_scatter:
torch-ext/triton_kernels/__init__.py
CHANGED
@@ -1,4 +1,11 @@
 # Make sure to add this in the build folder as this won't build if we put that here
+
 # from . import matmul_ogs, tensor_details, numerics_details, tensor, swiglu, routing

 # __all__ = ["matmul_ogs" , "tensor_details", "numerics_details", "tensor", "swiglu", "routing"]
+
+# Then, run the following code to build the kernels:
+# docker run --rm \
+#   -v $(pwd):/app \
+#   -w /app \
+#   ghcr.io/huggingface/kernel-builder:main
torch-ext/triton_kernels/matmul_ogs.py
CHANGED
@@ -602,6 +602,7 @@ def matmul_ogs_torch(x, w, bias,
                      betas = None,
                      gammas = None,
                      round_x = None, round_y = None,
+                     device: str = "cuda",
                      ):
     is_input_batched = x.ndim == 3
     assert x.dtype.itemsize > 1
@@ -641,7 +642,7 @@ def matmul_ogs_torch(x, w, bias,
         else:
             idx = gather_indx.src_indx[lo:hi] // n_expts_act
         batch = i if is_input_batched else 0
-        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=
+        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                            w[i].float())
         if bias is not None:
             out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
torch-ext/triton_kernels/matmul_ogs_details/_common.py
CHANGED
@@ -7,9 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor
 # -----------------------------------------------------------------------------
 # Utilities
 # -----------------------------------------------------------------------------
+try:
+    _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+    _parts = _ver_str.split(".")
+    _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+except Exception:
+    _ver_tuple = (0, 0, 0)

+if _ver_tuple > (3, 4, 0) and hasattr(triton, "constexpr_function"):
+    _constexpr_function = triton.constexpr_function
+else:
+    _constexpr_function = tl.constexpr_function

-
+
+
+@_constexpr_function
 def get_scaled_dot_format_string(dtype: tl.dtype):
     mapping = {
         tl.float16: "fp16",
torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py
CHANGED
@@ -4,25 +4,26 @@ from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ..target_info import cuda_capability_geq as _cuda_capability_geq
 from ..target_info import is_hip as _is_hip
+from ._common import _constexpr_function


 # fmt: off
-@
+@_constexpr_function
 def is_hip():
     return _is_hip()


-@
+@_constexpr_function
 def cuda_capability_geq(x, y):
     return _cuda_capability_geq(x, y)


-@
+@_constexpr_function
 def log2(n):
     return len(bin(n)) - 3


-@
+@_constexpr_function
 def _permute_to_end_order(n: int, axis: int):
     """
     Returns the order of the axes of a tensor to permute `axis` to the end.
@@ -105,7 +106,7 @@ def _finalize_matmul_launch_metadata(grid, kernel, args):
     return ret


-@
+@_constexpr_function
 def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
     """
     Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
CHANGED
@@ -12,14 +12,14 @@ from ..numerics_details.flexpoint import (
     compute_scale,
 )
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string, _constexpr_function


-@
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)

-@
+@_constexpr_function
 def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
     if isinstance(tensor_or_desc, tl.tensor):
         return tensor_or_desc.dtype.element_ty
torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py
CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import triton
 from ..target_info import get_cdna_version
 import torch
-from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel


 @dataclass
@@ -30,6 +30,83 @@ class OptFlags:
             raise ValueError("Not supported")


+def make_default_opt_flags_intel(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    elif routing_data.expected_tokens_per_expt is None:
+        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+    # pid swizzling
+    group_m = 8
+    xcd_swizzle = 1
+    # block_m
+    if constraints.get("block_m", None):
+        block_m = constraints["block_m"]
+    elif enforce_bitwise_invariance:
+        block_m = 128
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    # block n
+    block_n = opt_flags_intel.compute_block_n(n)
+    # is_persistent
+    is_persistent = constraints.get("is_persistent", False)
+    # block k
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
+    else:
+        block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+    # split_k
+    if constraints.get("split_k", None) is not None:
+        split_k = constraints["split_k"]
+    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+        split_k = 1
+    else:
+        estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+        split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+    epilogue_subtile = constraints.get('epilogue_subtile', None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+        num_stages=constraints.get("num_stages", 2),
+        fused_scatter=constraints.get('fused_scatter', False),
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=None,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=dict(),
+        idle_sms=0,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+

 def make_default_opt_flags_amd(
     out_dtype,
@@ -292,6 +369,8 @@ def make_opt_flags(
             enforce_bitwise_invariance, epilogue_effective_itemsize,
             _opt_flags_constraints]
    backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "xpu":
+        return make_default_opt_flags_intel(*args)
    if backend == "hip":
        return make_default_opt_flags_amd(*args)
    if backend == "cuda":