leonardlin committed on
Commit 867401e · 1 Parent(s): aeb3812

Gate ROCm grouped_gemm hipBLASLt behind env flag

_dev/TODO-gg-linter.md CHANGED
@@ -96,7 +96,7 @@ Both scripts consistently demonstrate:
  - ✅ **Fix implemented** — `_allocate_output` now returns a zeroed tensor
  - ✅ **Reproduction cases clean** — `_dev/debug-gg-small.py` and `_dev/debug-tensor-copy.py` match the Python reference
  - ✅ **hipify behavior understood** — edit `.cu`, not `.hip`, or adjust the build pipeline if we need custom HIP-only changes
- - ⚠️ **hipBLASLt path unsuitable** — re-enabling hipBLASLt caused HIP memory access faults on the large expert setups from `tests/ops_test.py`, so we reverted to the cleaned-up FP32 fallback for stability.
+ - ⚠️ **hipBLASLt path experimental** — enabling hipBLASLt via `MEGABLOCKS_GG_USE_HIPBLASLT=1` still triggers HIP memory access faults on the large expert setups from `tests/ops_test.py`. Leave the flag off for production; use the FP32 fallback until the hipBLASLt issues are resolved.
 
  ## Files Modified During Investigation
 
_dev/TODO-gg.md CHANGED
@@ -149,7 +149,7 @@ python debug-gg-step-by-step.py # Manual computation verification
  - **Misdiagnosed linter**: The perceived “linter” reverting our HIP edits was actually `hipify` regenerating `csrc/grouped_gemm/grouped_gemm.hip` from the CUDA source each time `build.sh` ran. Any HIP-only tweak has to live in `grouped_gemm.cu` (or we adjust the hipify step) to persist.
  - **Actual corruption cause**: The ROCm fallback path inside `hipblaslt_gmm_internal` accumulates into the output tensor passed from Python. `_allocate_output` in `torch-ext/megablocks/grouped_gemm/backend.py` created that buffer with `torch.empty`, so the accumulation mixed correct products with uninitialised memory, yielding the 10^17–10^25 explosions.
  - **Workaround**: Switching `_allocate_output` to use `torch.zeros` ensures the accumulation starts from a clean slate. After rebuilding, `_dev/debug-gg-small.py` and `_dev/debug-tensor-copy.py` now match the Python reference for all tested expert counts.
- - **hipBLASLt evaluation**: We briefly reinstated the hipBLASLt-backed path, but large expert batches triggered HIP memory access faults and the `run-tests.sh` suite aborted in `tests/ops_test.py`. We therefore kept the FP32 fallback in place for now, but stripped the debug prints and ensured it overwrites (rather than accumulates into) the destination tensor.
+ - **hipBLASLt evaluation**: We briefly reinstated the hipBLASLt-backed path, but large expert batches triggered HIP memory access faults and the `run-tests.sh` suite aborted in `tests/ops_test.py`. We therefore kept the stable FP32 fallback, which overwrites (rather than accumulates into) the destination tensor, as the default, and gated the hipBLASLt path behind the `MEGABLOCKS_GG_USE_HIPBLASLT` env var so we can experiment with it when desired.
  - **Next steps**: Leave the zero-initialisation in place while exploring a higher-performance HIP kernel; if we need HIP-specific logic, implement it in the `.cu` so hipify preserves the change.
 
  ```
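
For context on the `_allocate_output` fix referenced in the notes above, here is a minimal Python sketch. The real helper lives in `torch-ext/megablocks/grouped_gemm/backend.py`; its exact signature and the shape convention below are assumptions, not the committed code.

```python
import torch

def _allocate_output(a, b, batch_sizes, trans_b=False):
    """Sketch of the fix: allocate the output zero-filled, not with torch.empty.

    The ROCm fallback accumulates into this buffer, so handing it
    uninitialised memory is what produced the 10^17-10^25 blow-ups.
    """
    # Output columns come from the expert weights: b is (num_experts, k, n),
    # or (num_experts, n, k) when trans_b is set (shape convention assumed).
    n = b.shape[1] if trans_b else b.shape[2]
    # torch.zeros instead of torch.empty is the substance of the workaround.
    return torch.zeros(a.shape[0], n, device=a.device, dtype=a.dtype)
```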
_dev/TODO-hip.md ADDED
@@ -0,0 +1,19 @@
+ # HIP Grouped GEMM Status (2025-09-18)
+
+ ## Current toggle
+ - Set `MEGABLOCKS_GG_USE_HIPBLASLT=1` to force the ROCm build to run the hipBLASLt backend instead of the FP32 fallback in `hipblaslt_gmm_internal`.
+ - Without the flag the code uses the stable FP32 `torch::matmul` path that overwrites the destination buffer.
+
+ ## What works with hipBLASLt enabled
+ - `_dev/debug-gg-small.py`, `_dev/debug-tensor-copy.py`, and `_dev/debug-gg-detailed.py` finish with finite outputs (differences are within ~1e-3..1e-2 due to BF16).
+ - `python -m pytest tests/test_gg.py -q` passes with the flag set.
+
+ ## Known failures
+ - `PYTHONPATH=build/... MEGABLOCKS_GG_USE_HIPBLASLT=1 python -m pytest tests/ops_test.py -q` aborts with a HIP memory access fault (`Memory access fault by GPU node-2` during `OpsTest.testGroupedGemm_FixedSizes`).
+ - The same failure occurs early when the test suite is run via `run-tests.sh`, so hipBLASLt is not yet production-ready.
+
+ ## Next steps
+ - Reproduce the fault in isolation (likely the large `(z=16, m=128, k=128, n=128)` cases) and inspect the arguments passed into `hipblaslt_run_matmul` (leading dimensions/layout).
+ - Investigate whether hipBLASLt requires column-major layouts or a non-zero workspace to handle the grouped GEMM shapes.
+ - Consider a hybrid strategy: attempt hipBLASLt per expert and fall back to FP32 for shapes that exceed stability thresholds (e.g., by catching `hipblaslt_run_matmul` errors once we can reliably detect them).
+ - Once hipBLASLt is stable, tighten tolerances/grad checks in `tests/test_gg.py` and re-enable the high-performance path by default.
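
As a usage note for the toggle above: the C++ side reads and caches the flag once per process, so it must be set before the first grouped GEMM call. A hedged Python sketch follows; the `gg.gmm` entry point and import path are assumptions about the wrapper under `torch-ext/megablocks/grouped_gemm`, not confirmed API.

```python
import os

# Set the flag before the extension runs any grouped GEMM; the C++ side caches
# it on first use. (Equivalently: MEGABLOCKS_GG_USE_HIPBLASLT=1 python ...)
os.environ["MEGABLOCKS_GG_USE_HIPBLASLT"] = "1"

import torch
from megablocks.grouped_gemm import backend as gg  # hypothetical import path

a = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4, 128, 128, device="cuda", dtype=torch.bfloat16)
batch_sizes = torch.tensor([64, 64, 64, 64], dtype=torch.int64)

out = gg.gmm(a, b, batch_sizes)  # hypothetical entry point and signature

# FP32 reference; expect ~1e-3..1e-2 absolute differences from BF16.
ref = torch.cat([chunk.float() @ b[i].float()
                 for i, chunk in enumerate(torch.split(a, batch_sizes.tolist()))])
print((out.float() - ref).abs().max())
```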
csrc/grouped_gemm/grouped_gemm.cu CHANGED
@@ -7,10 +7,35 @@
 #include <hipblaslt/hipblaslt.h>
 #include <torch/autograd.h>
 #include <vector>
+ #include <algorithm>
+ #include <cctype>
+ #include <cstdlib>
+ #include <string>
 
 namespace grouped_gemm {
 namespace {
 
+ // Experimental: toggled via MEGABLOCKS_GG_USE_HIPBLASLT=1. This flag is
+ // intentionally off by default because the hipBLASLt path still fails on the
+ // largest `tests/ops_test.py` configurations.
+ bool use_hipblaslt_backend() {
+   static int cached = [] {
+     const char* raw = std::getenv("MEGABLOCKS_GG_USE_HIPBLASLT");
+     if (raw == nullptr) {
+       return 0;
+     }
+     std::string value(raw);
+     std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) {
+       return static_cast<char>(std::tolower(c));
+     });
+     if (value == "1" || value == "true" || value == "yes" || value == "on") {
+       return 1;
+     }
+     return 0;
+   }();
+   return cached == 1;
+ }
+
 inline void hipblaslt_check(hipblasStatus_t status, const char* expr) {
   TORCH_CHECK(status == HIPBLAS_STATUS_SUCCESS, "hipBLASLt call failed with status ", status, " when executing ", expr);
 }
@@ -152,6 +177,7 @@ torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
 
   auto device = a.device();
   auto dtype = a.scalar_type();
+  const bool use_hip = use_hipblaslt_backend();
 
   const auto counts_ptr = batch_sizes.data_ptr<int64_t>();
   const int64_t num_experts = batch_sizes.size(0);
@@ -174,28 +200,64 @@ torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
 
   auto b_contig = b.contiguous();
 
-  int64_t start = 0;
-  for (int64_t expert = 0; expert < num_experts; ++expert) {
-    const int64_t end = prefix[expert];
-    const int64_t rows = end - start;
-    auto out_chunk = out.select(0, expert);
-    if (rows == 0) {
-      out_chunk.zero_();
-      start = end;
-      continue;
-    }
-
-    auto a_slice = a.narrow(0, start, rows);
-    auto b_slice = b_contig.narrow(0, start, rows);
-
-    auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
-    auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
-
-    auto prod = torch::matmul(a_f32.transpose(0, 1), b_f32);
-    auto prod_bf16 = prod.to(dtype);
-
-    out_chunk.copy_(prod_bf16);
-    start = end;
-  }
+  if (use_hip) {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      auto out_chunk = out.select(0, expert);
+      if (rows == 0) {
+        out_chunk.zero_();
+        start = end;
+        continue;
+      }
+
+      auto a_chunk = a.narrow(0, start, rows).contiguous();
+      auto b_chunk = b_contig.narrow(0, start, rows).contiguous();
+
+      hipblaslt_run_matmul(a_chunk.data_ptr(),
+                           b_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           rows,
+                           hidden_in,
+                           rows,
+                           hidden_out,
+                           hidden_in,
+                           hidden_out,
+                           hidden_in,
+                           hidden_out,
+                           hidden_out,
+                           hidden_out,
+                           HIPBLAS_OP_T,
+                           HIPBLAS_OP_N,
+                           /*accumulate=*/false);
+      start = end;
+    }
+  } else {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      auto out_chunk = out.select(0, expert);
+      if (rows == 0) {
+        out_chunk.zero_();
+        start = end;
+        continue;
+      }
+
+      auto a_slice = a.narrow(0, start, rows);
+      auto b_slice = b_contig.narrow(0, start, rows);
+
+      auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
+      auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
+
+      auto prod = torch::matmul(a_f32.transpose(0, 1), b_f32);
+      auto prod_bf16 = prod.to(dtype);
+
+      out_chunk.copy_(prod_bf16);
+      start = end;
+    }
+  }
   return out;
 }
@@ -208,6 +270,104 @@ torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
 
   auto b_contig = b.contiguous();
 
+  if (use_hip) {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      if (rows == 0) {
+        start = end;
+        continue;
+      }
+      auto a_chunk = a.narrow(0, start, rows).contiguous();
+      auto b_chunk = b_contig.select(0, expert).contiguous();
+      auto out_chunk = out.narrow(0, start, rows);
+
+      hipblaslt_run_matmul(a_chunk.data_ptr(),
+                           b_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           rows,
+                           hidden_in,
+                           hidden_out,
+                           hidden_in,
+                           rows,
+                           hidden_out,
+                           hidden_in,
+                           hidden_in,
+                           hidden_out,
+                           hidden_out,
+                           HIPBLAS_OP_N,
+                           HIPBLAS_OP_T,
+                           /*accumulate=*/false);
+      start = end;
+    }
+  } else {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      if (rows == 0) {
+        start = end;
+        continue;
+      }
+      auto a_slice = a.narrow(0, start, rows);
+      auto b_slice = b_contig.select(0, expert);
+      auto out_chunk = out.narrow(0, start, rows);
+
+      auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
+      auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
+
+      auto prod = torch::matmul(a_f32, b_f32.transpose(0, 1));
+      auto prod_bf16 = prod.to(dtype);
+
+      out_chunk.copy_(prod_bf16);
+      start = end;
+    }
+  }
+  return out;
+  }
+
+  const int64_t hidden_out = a.size(1);
+  const int64_t hidden_in = b.size(2);
+  out = c_opt.value_or(torch::empty({tokens, hidden_in}, a.options()));
+  TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");
+
+  auto b_contig = b.contiguous();
+
+  if (use_hip) {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      if (rows == 0) {
+        start = end;
+        continue;
+      }
+      auto a_chunk = a.narrow(0, start, rows).contiguous();
+      auto b_chunk = b_contig.select(0, expert).contiguous();
+      auto out_chunk = out.narrow(0, start, rows);
+
+      hipblaslt_run_matmul(a_chunk.data_ptr(),
+                           b_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           rows,
+                           hidden_out,
+                           hidden_out,
+                           hidden_in,
+                           rows,
+                           hidden_in,
+                           hidden_out,
+                           hidden_in,
+                           hidden_in,
+                           hidden_in,
+                           HIPBLAS_OP_N,
+                           HIPBLAS_OP_N,
+                           /*accumulate=*/false);
+      start = end;
+    }
+  } else {
   int64_t start = 0;
   for (int64_t expert = 0; expert < num_experts; ++expert) {
     const int64_t end = prefix[expert];
@@ -223,42 +383,12 @@ torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
     auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
     auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
 
-    auto prod = torch::matmul(a_f32, b_f32.transpose(0, 1));
+    auto prod = torch::matmul(a_f32, b_f32);
     auto prod_bf16 = prod.to(dtype);
 
     out_chunk.copy_(prod_bf16);
     start = end;
   }
-  return out;
-  }
-
-  const int64_t hidden_out = a.size(1);
-  const int64_t hidden_in = b.size(2);
-  out = c_opt.value_or(torch::empty({tokens, hidden_in}, a.options()));
-  TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");
-
-  auto b_contig = b.contiguous();
-
-  int64_t start = 0;
-  for (int64_t expert = 0; expert < num_experts; ++expert) {
-    const int64_t end = prefix[expert];
-    const int64_t rows = end - start;
-    if (rows == 0) {
-      start = end;
-      continue;
-    }
-    auto a_slice = a.narrow(0, start, rows);
-    auto b_slice = b_contig.select(0, expert);
-    auto out_chunk = out.narrow(0, start, rows);
-
-    auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
-    auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
-
-    auto prod = torch::matmul(a_f32, b_f32);
-    auto prod_bf16 = prod.to(dtype);
-
-    out_chunk.copy_(prod_bf16);
-    start = end;
   }
   return out;
 }
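
To make the three per-expert branches in `hipblaslt_gmm_internal` easier to follow, here is a hedged pure-PyTorch sketch of what the FP32 fallback computes in each case; the `trans_a`/`trans_b` flags are an assumption about how the branch is selected, and only the per-expert math is taken from the diff above.

```python
import torch

def grouped_gemm_fp32_reference(a, b, batch_sizes, trans_a=False, trans_b=False):
    """Reference for the three FP32 fallback loops (branch selection assumed)."""
    splits = batch_sizes.tolist()
    a_chunks = torch.split(a, splits, dim=0)
    if trans_a:
        # Branch 1: a_e^T @ b_e per expert, stacked into a 3D, weight-shaped
        # output (matches the out.select(0, expert) loop; empty experts give zeros).
        b_chunks = torch.split(b, splits, dim=0)
        return torch.stack([ac.t().float() @ bc.float()
                            for ac, bc in zip(a_chunks, b_chunks)]).to(a.dtype)
    if trans_b:
        # Branch 2: tokens @ b_e^T per expert, concatenated along the token dim.
        return torch.cat([ac.float() @ b[i].float().t()
                          for i, ac in enumerate(a_chunks)]).to(a.dtype)
    # Branch 3: tokens @ b_e per expert.
    return torch.cat([ac.float() @ b[i].float()
                      for i, ac in enumerate(a_chunks)]).to(a.dtype)
```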
csrc/grouped_gemm/grouped_gemm.hip CHANGED
@@ -9,10 +9,32 @@
 #include <hipblaslt/hipblaslt.h>
 #include <torch/autograd.h>
 #include <vector>
+ #include <algorithm>
+ #include <cctype>
+ #include <cstdlib>
+ #include <string>
 
 namespace grouped_gemm {
 namespace {
 
+ bool use_hipblaslt_backend() {
+   static int cached = [] {
+     const char* raw = std::getenv("MEGABLOCKS_GG_USE_HIPBLASLT");
+     if (raw == nullptr) {
+       return 0;
+     }
+     std::string value(raw);
+     std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) {
+       return static_cast<char>(std::tolower(c));
+     });
+     if (value == "1" || value == "true" || value == "yes" || value == "on") {
+       return 1;
+     }
+     return 0;
+   }();
+   return cached == 1;
+ }
+
 inline void hipblaslt_check(hipblasStatus_t status, const char* expr) {
   TORCH_CHECK(status == HIPBLAS_STATUS_SUCCESS, "hipBLASLt call failed with status ", status, " when executing ", expr);
 }
@@ -154,6 +176,7 @@ torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
 
   auto device = a.device();
   auto dtype = a.scalar_type();
+  const bool use_hip = use_hipblaslt_backend();
 
   const auto counts_ptr = batch_sizes.data_ptr<int64_t>();
   const int64_t num_experts = batch_sizes.size(0);
@@ -176,28 +199,64 @@ torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
 
   auto b_contig = b.contiguous();
 
-  int64_t start = 0;
-  for (int64_t expert = 0; expert < num_experts; ++expert) {
-    const int64_t end = prefix[expert];
-    const int64_t rows = end - start;
-    auto out_chunk = out.select(0, expert);
-    if (rows == 0) {
-      out_chunk.zero_();
-      start = end;
-      continue;
-    }
-
-    auto a_slice = a.narrow(0, start, rows);
-    auto b_slice = b_contig.narrow(0, start, rows);
-
-    auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
-    auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
-
-    auto prod = torch::matmul(a_f32.transpose(0, 1), b_f32);
-    auto prod_bf16 = prod.to(dtype);
-
-    out_chunk.copy_(prod_bf16);
-    start = end;
-  }
+  if (use_hip) {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      auto out_chunk = out.select(0, expert);
+      if (rows == 0) {
+        out_chunk.zero_();
+        start = end;
+        continue;
+      }
+
+      auto a_chunk = a.narrow(0, start, rows).contiguous();
+      auto b_chunk = b_contig.narrow(0, start, rows).contiguous();
+
+      hipblaslt_run_matmul(a_chunk.data_ptr(),
+                           b_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           rows,
+                           hidden_in,
+                           rows,
+                           hidden_out,
+                           hidden_in,
+                           hidden_out,
+                           hidden_in,
+                           hidden_out,
+                           hidden_out,
+                           hidden_out,
+                           HIPBLAS_OP_T,
+                           HIPBLAS_OP_N,
+                           /*accumulate=*/false);
+      start = end;
+    }
+  } else {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      auto out_chunk = out.select(0, expert);
+      if (rows == 0) {
+        out_chunk.zero_();
+        start = end;
+        continue;
+      }
+
+      auto a_slice = a.narrow(0, start, rows);
+      auto b_slice = b_contig.narrow(0, start, rows);
+
+      auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
+      auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
+
+      auto prod = torch::matmul(a_f32.transpose(0, 1), b_f32);
+      auto prod_bf16 = prod.to(dtype);
+
+      out_chunk.copy_(prod_bf16);
+      start = end;
+    }
+  }
   return out;
 }
@@ -210,6 +269,104 @@ torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
 
   auto b_contig = b.contiguous();
 
+  if (use_hip) {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      if (rows == 0) {
+        start = end;
+        continue;
+      }
+      auto a_chunk = a.narrow(0, start, rows).contiguous();
+      auto b_chunk = b_contig.select(0, expert).contiguous();
+      auto out_chunk = out.narrow(0, start, rows);
+
+      hipblaslt_run_matmul(a_chunk.data_ptr(),
+                           b_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           rows,
+                           hidden_in,
+                           hidden_out,
+                           hidden_in,
+                           rows,
+                           hidden_out,
+                           hidden_in,
+                           hidden_in,
+                           hidden_out,
+                           hidden_out,
+                           HIPBLAS_OP_N,
+                           HIPBLAS_OP_T,
+                           /*accumulate=*/false);
+      start = end;
+    }
+  } else {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      if (rows == 0) {
+        start = end;
+        continue;
+      }
+      auto a_slice = a.narrow(0, start, rows);
+      auto b_slice = b_contig.select(0, expert);
+      auto out_chunk = out.narrow(0, start, rows);
+
+      auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
+      auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
+
+      auto prod = torch::matmul(a_f32, b_f32.transpose(0, 1));
+      auto prod_bf16 = prod.to(dtype);
+
+      out_chunk.copy_(prod_bf16);
+      start = end;
+    }
+  }
+  return out;
+  }
+
+  const int64_t hidden_out = a.size(1);
+  const int64_t hidden_in = b.size(2);
+  out = c_opt.value_or(torch::empty({tokens, hidden_in}, a.options()));
+  TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");
+
+  auto b_contig = b.contiguous();
+
+  if (use_hip) {
+    int64_t start = 0;
+    for (int64_t expert = 0; expert < num_experts; ++expert) {
+      const int64_t end = prefix[expert];
+      const int64_t rows = end - start;
+      if (rows == 0) {
+        start = end;
+        continue;
+      }
+      auto a_chunk = a.narrow(0, start, rows).contiguous();
+      auto b_chunk = b_contig.select(0, expert).contiguous();
+      auto out_chunk = out.narrow(0, start, rows);
+
+      hipblaslt_run_matmul(a_chunk.data_ptr(),
+                           b_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           out_chunk.data_ptr(),
+                           rows,
+                           hidden_out,
+                           hidden_out,
+                           hidden_in,
+                           rows,
+                           hidden_in,
+                           hidden_out,
+                           hidden_in,
+                           hidden_in,
+                           hidden_in,
+                           HIPBLAS_OP_N,
+                           HIPBLAS_OP_N,
+                           /*accumulate=*/false);
+      start = end;
+    }
+  } else {
   int64_t start = 0;
   for (int64_t expert = 0; expert < num_experts; ++expert) {
     const int64_t end = prefix[expert];
@@ -225,42 +382,12 @@ torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
     auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
     auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
 
-    auto prod = torch::matmul(a_f32, b_f32.transpose(0, 1));
+    auto prod = torch::matmul(a_f32, b_f32);
    auto prod_bf16 = prod.to(dtype);
 
     out_chunk.copy_(prod_bf16);
     start = end;
   }
-  return out;
-  }
-
-  const int64_t hidden_out = a.size(1);
-  const int64_t hidden_in = b.size(2);
-  out = c_opt.value_or(torch::empty({tokens, hidden_in}, a.options()));
-  TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");
-
-  auto b_contig = b.contiguous();
-
-  int64_t start = 0;
-  for (int64_t expert = 0; expert < num_experts; ++expert) {
-    const int64_t end = prefix[expert];
-    const int64_t rows = end - start;
-    if (rows == 0) {
-      start = end;
-      continue;
-    }
-    auto a_slice = a.narrow(0, start, rows);
-    auto b_slice = b_contig.select(0, expert);
-    auto out_chunk = out.narrow(0, start, rows);
-
-    auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
-    auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
-
-    auto prod = torch::matmul(a_f32, b_f32);
-    auto prod_bf16 = prod.to(dtype);
-
-    out_chunk.copy_(prod_bf16);
-    start = end;
   }
   return out;
 }