Commit 09e15a7 · drbh committed
1 Parent(s): 3bdb4b8
fix: add quickstart and avoid autotune when no cuda

Files changed:
- README.md +54 -30
- readme_example.py +51 -0
- torch-ext/megablocks/backend/kernels.py +14 -0
README.md CHANGED

@@ -4,39 +4,63 @@ tags:
 - kernel
 ---
 
-
+## Quickstart
 
 ```bash
-
+uv run https://huggingface.co/kernels-community/megablocks/raw/main/readme_example.py
 ```
 
-expected output:
 
+```python
+# /// script
+# requires-python = "==3.10"
+# dependencies = [
+#     "numpy",
+#     "kernels",
+#     "torch"
+# ]
+# ///
+
+import torch
+from collections import namedtuple
+
+from kernels import get_kernel
+
+# Make reproducible
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+
+# Download optimized kernels from the Hugging Face hub
+megablocks = get_kernel("kernels-community/megablocks")
+print("MegaBlocks kernel downloaded successfully.")
+
+model = megablocks.layers.MegaBlocksMoeMLP()
+model.experts = namedtuple("Experts", ["gate_up_proj", "gate_down_proj", "down_proj", "hidden_size"])
+print("MegaBlocksMoeMLP instance created successfully.")
+
+# Config
+ne, hs, isz = 128, 1152, 3072
+
+# Router with proper initialization
+model.router = torch.nn.Linear(hs, ne, device="cuda")
+torch.nn.init.kaiming_uniform_(model.router.weight)
+
+# Expert layers with realistic weights
+e = model.experts
+e.gate_up_proj = torch.nn.Parameter(torch.randn(ne, hs, isz, device="cuda") * 0.02)
+e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
+e.down_proj = torch.nn.Parameter(torch.randn(ne, 1536, hs, device="cuda") * 0.02)
+e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
+e.hidden_size = hs
+print("Expert layers initialized successfully.")
+
+# Test with normalized input
+x = torch.randn(1, 1, hs, device="cuda") * 0.1
+output, expert_weights = model(x)
+print("Model forward pass completed successfully.")
+
+print(f"Output shape: {output.shape}")
+print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
+print(f"Output: {output.flatten()[:10]}")
+print(f"Expert weights sum: {expert_weights.sum():.3f}")
 ```
-============== test session starts ===============
-platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.5.0
-rootdir: /home/ubuntu/Projects/megablocks-moe
-plugins: hypothesis-6.130.12
-collecting 43 items world_size=1
-collected 387 items
-
-tests/layers/moe_test.py ...........................................
-tests/ops/binned_gather_test.py .....................
-tests/ops/binned_scatter_test.py .....................
-tests/ops/cumsum_test.py ................................
-tests/ops/histogram_test.py ......................................................
-tests/ops/padded_gather_test.py ......................................
-tests/ops/padded_scatter_test.py ......................................................
-tests/ops/replicate_test.py ..................................................................................
-tests/ops/sort_test.py ..................
-tests/ops/topology_test.py ....................
-tests/test_mb_moe.py megablocks_moe module imported successfully.
-Available functions: ['Arguments', 'MLP', 'MoE', 'ParallelDroplessMLP', 'ParallelMLP', 'SparseGLU', 'SparseMLP', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_megablocks_a4f6452_dirty', '_ops', 'argsort', 'backend', 'cumsum', 'dMoE', 'exclusive_cumsum', 'get_load_balancing_loss', 'grouped_gemm_util', 'histogram', 'inclusive_cumsum', 'indices', 'layers', 'ops', 'replicate_backward', 'replicate_forward', 'sort', 'torch']
-.cumsum output: tensor([0, 1, 3, 6], device='cuda:0', dtype=torch.int16)
-...
-
-================ warnings summary ================
-...
--- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-======= 387 passed, 18 warnings in 54.63s ========
-```
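A note on the quickstart command: the `# /// script` header at the top of the new example is inline script metadata (PEP 723). `uv run` reads it to resolve the listed dependencies before executing the file, which is what lets the one-line `uv run <URL>` invocation above work without a pre-built environment.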
readme_example.py ADDED

@@ -0,0 +1,51 @@
+# /// script
+# requires-python = "==3.10"
+# dependencies = [
+#     "numpy",
+#     "kernels",
+#     "torch"
+# ]
+# ///
+
+import torch
+from collections import namedtuple
+
+from kernels import get_kernel
+
+# Make reproducible
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+
+# Download optimized kernels from the Hugging Face hub
+megablocks = get_kernel("kernels-community/megablocks")
+print("MegaBlocks kernel downloaded successfully.")
+
+model = megablocks.layers.MegaBlocksMoeMLP()
+model.experts = namedtuple("Experts", ["gate_up_proj", "gate_down_proj", "down_proj", "hidden_size"])
+print("MegaBlocksMoeMLP instance created successfully.")
+
+# Config
+ne, hs, isz = 128, 1152, 3072
+
+# Router with proper initialization
+model.router = torch.nn.Linear(hs, ne, device="cuda")
+torch.nn.init.kaiming_uniform_(model.router.weight)
+
+# Expert layers with realistic weights
+e = model.experts
+e.gate_up_proj = torch.nn.Parameter(torch.randn(ne, hs, isz, device="cuda") * 0.02)
+e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
+e.down_proj = torch.nn.Parameter(torch.randn(ne, 1536, hs, device="cuda") * 0.02)
+e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
+e.hidden_size = hs
+print("Expert layers initialized successfully.")
+
+# Test with normalized input
+x = torch.randn(1, 1, hs, device="cuda") * 0.1
+output, expert_weights = model(x)
+print("Model forward pass completed successfully.")
+
+print(f"Output shape: {output.shape}")
+print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
+print(f"Output: {output.flatten()[:10]}")
+print(f"Expert weights sum: {expert_weights.sum():.3f}")
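One detail of the example worth flagging: `namedtuple("Experts", [...])` returns a *class*, not an instance, so the script attaches the expert weights as attributes on that generated type. It works as a lightweight attribute container; the sketch below (a hypothetical simplification, not part of the repo) shows the same idea with `types.SimpleNamespace`:

```python
from types import SimpleNamespace

# Hypothetical sketch: SimpleNamespace is a plain attribute container,
# equivalent to how the example uses the generated namedtuple class.
experts = SimpleNamespace()
experts.hidden_size = 1152   # attach fields freely, as the example does
experts.gate_up_proj = None  # placeholder for the real torch.nn.Parameter
print(experts.hidden_size)
```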
torch-ext/megablocks/backend/kernels.py CHANGED

@@ -5,6 +5,20 @@ import torch
 import triton
 import triton.language as tl
 
+# Stub triton autotune when testing in an env that does not have CUDA;
+# this approach preserves the original code but enables testing without a GPU.
+if torch.cuda.is_available() is False:
+    import warnings
+
+    warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+    def _no_autotune(*args, **kwargs):
+        def deco(fn):
+            return fn
+        return deco
+
+    triton.autotune = _no_autotune
+
 
 def assert_is_tensor(x, ndim):
     if x.ndim != ndim:
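To see why this stub is safe, note that `triton.autotune` is a decorator factory: it takes configuration arguments and returns a decorator. The replacement keeps that shape but hands the kernel function back unchanged, so `@triton.autotune(...)`-decorated definitions still import cleanly on CPU-only machines. A minimal self-contained sketch of the pattern (independent of Triton, so it runs anywhere; the decorator arguments are illustrative):

```python
# No-op decorator factory: accepts any arguments, returns the function untouched.
def _no_autotune(*args, **kwargs):
    def deco(fn):
        return fn
    return deco

@_no_autotune(configs=[], key=["n"])  # arguments are accepted and ignored
def double(n):
    return n * 2

assert double(21) == 42  # the decorated function's behavior is unchanged
```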