Spaces:
Running
on
T4
Running
on
T4
sparkleman
committed on
Commit
·
ff3952a
1
Parent(s):
05b6df6
UPDATE: Merge cuda core from BlinkDL/RWKV-Gradio-1
Browse files- Dockerfile +1 -1
- app.py +2 -1
- cuda/gemm_fp16_cublas.cpp +75 -0
- cuda/operators.cu +246 -0
- cuda/rwkv5.cu +88 -0
- cuda/rwkv5_op.cpp +34 -0
- cuda/rwkv6.cu +87 -0
- cuda/rwkv6_op.cpp +34 -0
- cuda/wrapper.cpp +141 -0
- pyproject.toml +2 -0
- uv.lock +28 -0
Dockerfile
CHANGED
|
@@ -26,4 +26,4 @@ COPY --chown=user . $HOME/app
|
|
| 26 |
|
| 27 |
RUN uv sync --frozen --extra cu124
|
| 28 |
|
| 29 |
-
CMD ["uv","run","app.py","--strategy","cuda fp16","--model_title","RWKV-x070-World-0.1B-v2.8-20241210-ctx4096","--download_repo_id","BlinkDL/rwkv-7-world","--host","0.0.0.0","--port","7860"]
|
|
|
|
| 26 |
|
| 27 |
RUN uv sync --frozen --extra cu124
|
| 28 |
|
| 29 |
+
CMD ["uv","run","app.py","--strategy","cuda fp16","--model_title","RWKV-x070-World-0.1B-v2.8-20241210-ctx4096","--download_repo_id","BlinkDL/rwkv-7-world","--host","0.0.0.0","--port","7860","--RWKV_CUDA_ON","True"]
|
app.py
CHANGED
|
@@ -26,6 +26,7 @@ class Config(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=Tr
|
|
| 26 |
description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
|
| 27 |
)
|
| 28 |
VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
CONFIG = Config()
|
|
@@ -51,7 +52,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
|
|
| 51 |
os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
|
| 52 |
os.environ["RWKV_JIT_ON"] = "1"
|
| 53 |
os.environ["RWKV_CUDA_ON"] = (
|
| 54 |
-
"
|
| 55 |
)
|
| 56 |
|
| 57 |
from rwkv.model import RWKV
|
|
|
|
| 26 |
description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
|
| 27 |
)
|
| 28 |
VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
|
| 29 |
+
RWKV_CUDA_ON:bool = Field(False, description="`True` to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!")
|
| 30 |
|
| 31 |
|
| 32 |
CONFIG = Config()
|
|
|
|
| 52 |
os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
|
| 53 |
os.environ["RWKV_JIT_ON"] = "1"
|
| 54 |
os.environ["RWKV_CUDA_ON"] = (
|
| 55 |
+
"1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
|
| 56 |
)
|
| 57 |
|
| 58 |
from rwkv.model import RWKV
|
cuda/gemm_fp16_cublas.cpp
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>

// Report cuBLAS failures as exceptions.  The `for` form makes the macro a
// single statement, so it composes safely with if/else and a trailing `;`.
#define CUBLAS_CHECK(condition)                                  \
  for (cublasStatus_t _cublas_check_status = (condition);        \
       _cublas_check_status != CUBLAS_STATUS_SUCCESS;)           \
    throw std::runtime_error("cuBLAS error " +                   \
                             std::to_string(_cublas_check_status) + " at " + \
                             std::to_string(__LINE__));

#define CUDA_CHECK(condition)                                    \
  for (cudaError_t _cuda_check_status = (condition);             \
       _cuda_check_status != cudaSuccess;)                       \
    throw std::runtime_error(                                    \
        "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
        " at " + std::to_string(__LINE__));

/*
NOTE: blas gemm is column-major by default, but we need row-major output.
The data of row-major, transposed matrix is exactly the same as the
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/
// Computes c = a @ b for fp16 row-major torch tensors with fp32 accumulation.
// `c` may be fp16 or fp32.  Handles both the 2-D case (plain GEMM) and the
// 3-D case (strided-batch GEMM); in the 3-D case all three tensors are
// assumed to share the batch size and to be contiguous per matrix.
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  const auto cuda_data_type = CUDA_R_16F;
  const auto cuda_c_data_type =
      c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
  const auto compute_type = CUDA_R_32F;  // accumulate in fp32 for accuracy
  const float sp_alpha = 1.f;
  // swap a and b, and use CUBLAS_OP_N. see the notes above
  std::swap(a, b);
  const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
  const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
  // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
  // negative axis is used because of the existence of batch matmul.
  const int m = a.size(-1);
  const int k = a.size(-2);
  const int n = b.size(-2);
  const int cublas_lda = m;
  const int cublas_ldb = k;
  const int cublas_ldc = m;
  cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();

#if CUDA_VERSION >= 11000
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#else
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif
  const float sp_beta = 0.f;
  if (a.sizes().size() == 2 && b.sizes().size() == 2) {
    CUBLAS_CHECK(cublasGemmEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
        a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
        cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
        compute_type, algo));
  } else {
    // batch matmul
    assert(a.sizes().size() == 3 && b.sizes().size() == 3);

    // Widen BEFORE multiplying: `m * k` in 32-bit int can overflow for large
    // matrices even though the destination type is long long.
    const long long int cublas_stride_a = static_cast<long long int>(m) * k;
    const long long int cublas_stride_b = static_cast<long long int>(k) * n;
    const long long int cublas_stride_c = static_cast<long long int>(m) * n;
    CUBLAS_CHECK(cublasGemmStridedBatchedEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m,
        n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
        cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
        &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
        a.size(0), compute_type, algo));
  }
}
|
cuda/operators.cu
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#define MIN_VALUE (-1e38)
typedef at::Half fp16;

// at::Half and __half are layout-compatible; reinterpret for device kernels.
__half *cast(fp16 *ptr) {
    return reinterpret_cast<__half *>(ptr);
}

// WKV recurrence (RWKV-4 style) over T timesteps for one (batch, channel)
// pair per thread.  `aa`/`bb`/`pp` carry the numerically-stable running
// state: `pp` is a running log-domain maximum used to keep the exponentials
// in range.  State is read at entry and written back at exit so the kernel
// can be called chunk-by-chunk.
template <typename F>
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
                               const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;            // batch index
    const int _c = idx % C;            // channel index
    const int _offset = _b * T * C + _c;
    const int _state_offset = _b * C + _c;

    float u = _u[_c];
    float w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    float aa = _aa[_state_offset];
    float bb = _bb[_state_offset];
    float pp = _pp[_state_offset];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        // Output: weighted average including the current token's bonus `u`.
        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
        // State update: decay by `w`, then absorb the current token.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    _aa[_state_offset] = aa;
    _bb[_state_offset] = bb;
    _pp[_state_offset] = pp;
}

// Host launcher: one thread per (batch, channel).  Requires B*C to be a
// multiple of the block size chosen below.
template <typename F>
void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
}

// Explicit instantiations for the two supported activation dtypes.
template void cuda_wkv_forward<fp16>(
    int B, int T, int C,
    float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
    float *aa, float *bb, float *pp);
template void cuda_wkv_forward<float>(
    int B, int T, int C,
    float *w, float *u, float *k, float *v, float *y,
    float *aa, float *bb, float *pp);

// Sequence (batched) matmul against an int8-quantized weight `w`, dequantized
// on the fly as (w + 0.5) * rx[k] * ry[j] + mx[k] + my[j].  fp32 activations.
__global__ void kernel_mm_seq_fp32i8(
    const int B, const int N, const int M,
    const float *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;  // row of x / y
    const int k = blockIdx.y * blockDim.y + threadIdx.y;  // output column

    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += x[i * x_stride + j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        y[i * y_stride + k] = y_local;
    }
}

template <typename F>
void cuda_mm8_seq(int B, int N, int M,
                  F *x, int x_stride,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  F *y, int y_stride);

template <>
void cuda_mm8_seq<float>(int B, int N, int M,
                         float *x, int x_stride,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
        B, N, M, x, x_stride, w, w_stride,
        mx, rx, my, ry, y, y_stride);
}

// Same dequantizing matmul for fp16 activations; math is still done in fp32
// and rounded to fp16 only on the final store.
__global__ void kernel_mm_seq_fp16i8(
    const int B, const int N, const int M,
    const __half *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    __half *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int k = blockIdx.y * blockDim.y + threadIdx.y;

    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += __half2float(x[i * x_stride + j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        y[i * y_stride + k] = __float2half(y_local);
    }
}

template <>
void cuda_mm8_seq<fp16>(int B, int N, int M,
                        fp16 *x, int x_stride,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        fp16 *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
        B, N, M, cast(x), x_stride, w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
}

// Vector-matrix product (single token) split along the reduction dimension
// into MM8_ONE_JSPLIT partial sums combined with atomicAdd.  The caller is
// expected to zero `y` beforehand.
#define MM8_ONE_JSPLIT 24
#define MM8_ONE_TILE 1024

__global__ void kernel_mm_one_fp32i8(
    const int N, const int M,
    const float *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;
    // [j0, j1) is this block's slice of the reduction axis.
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            y_local += x[j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        atomicAdd(&y[k], y_local);
    }
}

template <typename F>
void cuda_mm8_one(int N, int M,
                  F *x,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  float *y);

template <>
void cuda_mm8_one<float>(int N, int M,
                         float *x,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
        N, M, x, w, w_stride,
        mx, rx, my, ry, y);
}

// fp16-activation variant of the single-token product; accumulation and the
// output buffer stay fp32 (atomicAdd on float).
__global__ void kernel_mm_one_fp16i8(
    const int N, const int M,
    const __half *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            y_local += __half2float(x[j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        atomicAdd(&y[k], y_local);
    }
}

template <>
void cuda_mm8_one<fp16>(int N, int M,
                        fp16 *x,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
        N, M, cast(x), w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), y);
}
|
cuda/rwkv5.cu
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

// RWKV-5 forward kernel.  Launch: grid = B*H blocks, block = _N_ threads
// (one thread per channel inside a head; _N_ is the head size, supplied as a
// compile-time macro and must be a multiple of 4 for the float4 accesses).
// `_state` holds the per-head _N_ x _N_ recurrent state in fp32.
// NOTE(upstream): the state offset ignores the batch index, so the state is
// only correct for B == 1 — kept as-is to preserve behavior.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;   // batch index
    const int h = blockIdx.x % H;   // head index
    const int i = threadIdx.x;      // channel within the head
    _w += h*_N_;
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    // Each thread keeps one row of the head's state in registers.
    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    // u and w are constant over time for RWKV-5; stage them once.
    __syncthreads();
    u[i] = float(_u[i]);
    w[i] = _w[i];
    __syncthreads();

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        // Stage this timestep's r and k for the whole head.
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        // float4 casts process 4 channels per iteration (requires _N_ % 4 == 0).
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            // y += r * (u .* (k*v) + state): current token gets the `u` bonus.
            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            // state = state .* w + k*v: decay then absorb the token.
            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}

// Host launchers for the three supported activation dtypes; decay `w` and
// the recurrent state are always fp32.
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
|
cuda/rwkv5_op.cpp
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

// Implemented in rwkv5.cu.
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);

// Tensor-level wrappers: pin the current device to `state`'s device, then
// forward raw pointers to the CUDA launchers.  Dtypes must already match the
// chosen overload (no checks here — callers are trusted).
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

// Expose via both pybind11 (import as a Python module) and the TorchScript
// operator registry (torch.ops.rwkv5.*).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
}
TORCH_LIBRARY(rwkv5, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
|
cuda/rwkv6.cu
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

// RWKV-6 forward kernel.  Same layout as rwkv5.cu (grid = B*H, block = _N_,
// one thread per channel in a head, _N_ must be a multiple of 4), but the
// decay `w` is data-dependent: it varies per timestep and is reloaded inside
// the time loop instead of being staged once.
// NOTE(upstream): the state offset ignores the batch index, so the state is
// only correct for B == 1 — kept as-is to preserve behavior.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;   // batch index
    const int h = blockIdx.x % H;   // head index
    const int i = threadIdx.x;      // channel within the head
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    // One row of the head's recurrent state lives in this thread's registers.
    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    // Only `u` is time-invariant in RWKV-6; stage it once.
    __syncthreads();
    u[i] = float(_u[i]);
    __syncthreads();

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        // Stage this timestep's w, r and k for the whole head.
        __syncthreads();
        w[i] = _w[t];
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        // float4 casts process 4 channels per iteration (requires _N_ % 4 == 0).
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            // y += r * (u .* (k*v) + state): current token gets the `u` bonus.
            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            // state = state .* w(t) + k*v: per-step decay, then absorb token.
            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}

// Host launchers for the three supported activation dtypes; decay `w` and
// the recurrent state are always fp32.
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
|
cuda/rwkv6_op.cpp
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

// Implemented in rwkv6.cu.
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);

// Tensor-level wrappers: pin the current device to `state`'s device, then
// forward raw pointers to the CUDA launchers.  Dtypes must already match the
// chosen overload (no checks here — callers are trusted).
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

// Expose via both pybind11 (import as a Python module) and the TorchScript
// operator registry (torch.ops.rwkv6.*).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
}
TORCH_LIBRARY(rwkv6, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
|
cuda/wrapper.cpp
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include "ATen/ATen.h"
|
| 3 |
+
#include <iostream>
|
| 4 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 5 |
+
|
| 6 |
+
typedef at::Half fp16;
|
| 7 |
+
|
| 8 |
+
template <typename F>
|
| 9 |
+
void cuda_wkv_forward(int B, int T, int C,
|
| 10 |
+
float *w, float *u, F *k, F *v, F *y,
|
| 11 |
+
float *aa, float *bb, float *pp);
|
| 12 |
+
template <typename F>
|
| 13 |
+
void cuda_mm8_seq(int B, int N, int M,
|
| 14 |
+
F *x, int x_stride,
|
| 15 |
+
uint8_t *w, int w_stride,
|
| 16 |
+
F *mx, F *rx,
|
| 17 |
+
F *my, F *ry,
|
| 18 |
+
F *y, int y_stride);
|
| 19 |
+
template <typename F>
|
| 20 |
+
void cuda_mm8_one(int N, int M,
|
| 21 |
+
F *x,
|
| 22 |
+
uint8_t *w, int w_stride,
|
| 23 |
+
F *mx, F *rx,
|
| 24 |
+
F *my, F *ry,
|
| 25 |
+
float *y);
|
| 26 |
+
|
| 27 |
+
void wkv_forward(int64_t B, int64_t T, int64_t C,
|
| 28 |
+
torch::Tensor &w, torch::Tensor &u,
|
| 29 |
+
torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
|
| 30 |
+
torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
|
| 31 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
|
| 32 |
+
switch (k.scalar_type()) {
|
| 33 |
+
case c10::ScalarType::Half:
|
| 34 |
+
cuda_wkv_forward(B, T, C,
|
| 35 |
+
w.data_ptr<float>(), u.data_ptr<float>(),
|
| 36 |
+
k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
|
| 37 |
+
aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
|
| 38 |
+
break;
|
| 39 |
+
case c10::ScalarType::Float:
|
| 40 |
+
cuda_wkv_forward(B, T, C,
|
| 41 |
+
w.data_ptr<float>(), u.data_ptr<float>(),
|
| 42 |
+
k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
|
| 43 |
+
aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
|
| 44 |
+
break;
|
| 45 |
+
default:
|
| 46 |
+
assert(false && "Only FP16 and FP32 are currently supported");
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
void mm8_seq(int64_t B, int64_t N, int64_t M,
|
| 51 |
+
torch::Tensor &x, torch::Tensor &w,
|
| 52 |
+
torch::Tensor &mx, torch::Tensor &rx,
|
| 53 |
+
torch::Tensor &my, torch::Tensor &ry,
|
| 54 |
+
torch::Tensor &y) {
|
| 55 |
+
assert(x.stride(1) == 1);
|
| 56 |
+
assert(w.stride(1) == 1);
|
| 57 |
+
assert(mx.stride(0) == 1 && rx.stride(0) == 1);
|
| 58 |
+
assert(my.stride(0) == 1 && ry.stride(0) == 1);
|
| 59 |
+
assert(y.stride(1) == 1);
|
| 60 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
|
| 61 |
+
switch (x.scalar_type()) {
|
| 62 |
+
case c10::ScalarType::Half:
|
| 63 |
+
cuda_mm8_seq(
|
| 64 |
+
B, N, M,
|
| 65 |
+
x.data_ptr<fp16>(), x.stride(0),
|
| 66 |
+
w.data_ptr<uint8_t>(), w.stride(0),
|
| 67 |
+
mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
|
| 68 |
+
my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
|
| 69 |
+
y.data_ptr<fp16>(), y.stride(0));
|
| 70 |
+
break;
|
| 71 |
+
case c10::ScalarType::Float:
|
| 72 |
+
cuda_mm8_seq(
|
| 73 |
+
B, N, M,
|
| 74 |
+
x.data_ptr<float>(), x.stride(0),
|
| 75 |
+
w.data_ptr<uint8_t>(), w.stride(0),
|
| 76 |
+
mx.data_ptr<float>(), rx.data_ptr<float>(),
|
| 77 |
+
my.data_ptr<float>(), ry.data_ptr<float>(),
|
| 78 |
+
y.data_ptr<float>(), y.stride(0));
|
| 79 |
+
break;
|
| 80 |
+
default:
|
| 81 |
+
assert(false && "Only FP16 and FP32 are currently supported");
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
void mm8_one(int64_t N, int64_t M,
|
| 85 |
+
torch::Tensor &x, torch::Tensor &w,
|
| 86 |
+
torch::Tensor &mx, torch::Tensor &rx,
|
| 87 |
+
torch::Tensor &my, torch::Tensor &ry,
|
| 88 |
+
torch::Tensor &y) {
|
| 89 |
+
assert(x.stride(0) == 1);
|
| 90 |
+
assert(w.stride(1) == 1);
|
| 91 |
+
assert(mx.stride(0) == 1 && rx.stride(0) == 1);
|
| 92 |
+
assert(my.stride(0) == 1 && ry.stride(0) == 1);
|
| 93 |
+
assert(y.stride(0) == 1);
|
| 94 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
|
| 95 |
+
switch (x.scalar_type()) {
|
| 96 |
+
case c10::ScalarType::Half:
|
| 97 |
+
cuda_mm8_one(
|
| 98 |
+
N, M,
|
| 99 |
+
x.data_ptr<fp16>(),
|
| 100 |
+
w.data_ptr<uint8_t>(), w.stride(0),
|
| 101 |
+
mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
|
| 102 |
+
my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
|
| 103 |
+
y.data_ptr<float>());
|
| 104 |
+
break;
|
| 105 |
+
case c10::ScalarType::Float:
|
| 106 |
+
cuda_mm8_one(
|
| 107 |
+
N, M,
|
| 108 |
+
x.data_ptr<float>(),
|
| 109 |
+
w.data_ptr<uint8_t>(), w.stride(0),
|
| 110 |
+
mx.data_ptr<float>(), rx.data_ptr<float>(),
|
| 111 |
+
my.data_ptr<float>(), ry.data_ptr<float>(),
|
| 112 |
+
y.data_ptr<float>());
|
| 113 |
+
break;
|
| 114 |
+
default:
|
| 115 |
+
assert(false && "Only FP16 and FP32 are currently supported");
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
using torch::Tensor;
|
| 120 |
+
|
| 121 |
+
#ifndef DISABLE_CUBLAS_GEMM
|
| 122 |
+
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
| 123 |
+
#endif
|
| 124 |
+
|
| 125 |
+
// Python bindings for this wrapper's CUDA ops. The cuBLAS GEMM binding is
// omitted entirely when the extension is built with -DDISABLE_CUBLAS_GEMM.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wkv_forward", &wkv_forward, "wkv forward");
    m.def("mm8_seq", &mm8_seq, "mm8 seq");
    m.def("mm8_one", &mm8_one, "mm8 one");
#ifndef DISABLE_CUBLAS_GEMM
    // NOTE(review): description string says "gemv" while the bound function
    // is gemm_fp16_cublas — presumably a leftover label from upstream;
    // confirm before renaming (the string is only a pybind11 docstring).
    m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
#endif
}
|
| 133 |
+
|
| 134 |
+
// TorchScript operator registrations mirroring the pybind11 bindings,
// under the `rwkv::` namespace. The cuBLAS GEMM op is likewise excluded
// when DISABLE_CUBLAS_GEMM is defined at build time.
TORCH_LIBRARY(rwkv, m) {
    m.def("wkv_forward", wkv_forward);
    m.def("mm8_seq", mm8_seq);
    m.def("mm8_one", mm8_one);
#ifndef DISABLE_CUBLAS_GEMM
    m.def("gemm_fp16_cublas", gemm_fp16_cublas);
#endif
}
|
pyproject.toml
CHANGED
|
@@ -8,11 +8,13 @@ dependencies = [
|
|
| 8 |
"fastapi[standard]>=0.115.11",
|
| 9 |
"huggingface-hub>=0.29.1",
|
| 10 |
"loguru>=0.7.3",
|
|
|
|
| 11 |
"numpy>=2.2.3",
|
| 12 |
"pydantic>=2.10.6",
|
| 13 |
"pydantic-settings>=2.8.1",
|
| 14 |
"pynvml>=12.0.0",
|
| 15 |
"rwkv==0.8.28",
|
|
|
|
| 16 |
"snowflake-id>=1.0.2",
|
| 17 |
]
|
| 18 |
|
|
|
|
| 8 |
"fastapi[standard]>=0.115.11",
|
| 9 |
"huggingface-hub>=0.29.1",
|
| 10 |
"loguru>=0.7.3",
|
| 11 |
+
"ninja>=1.11.1.3",
|
| 12 |
"numpy>=2.2.3",
|
| 13 |
"pydantic>=2.10.6",
|
| 14 |
"pydantic-settings>=2.8.1",
|
| 15 |
"pynvml>=12.0.0",
|
| 16 |
"rwkv==0.8.28",
|
| 17 |
+
"setuptools>=75.8.2",
|
| 18 |
"snowflake-id>=1.0.2",
|
| 19 |
]
|
| 20 |
|
uv.lock
CHANGED
|
@@ -446,6 +446,30 @@ wheels = [
|
|
| 446 |
{ url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 },
|
| 447 |
]
|
| 448 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
[[package]]
|
| 450 |
name = "numpy"
|
| 451 |
version = "2.2.3"
|
|
@@ -915,11 +939,13 @@ dependencies = [
|
|
| 915 |
{ name = "fastapi", extra = ["standard"] },
|
| 916 |
{ name = "huggingface-hub" },
|
| 917 |
{ name = "loguru" },
|
|
|
|
| 918 |
{ name = "numpy" },
|
| 919 |
{ name = "pydantic" },
|
| 920 |
{ name = "pydantic-settings" },
|
| 921 |
{ name = "pynvml" },
|
| 922 |
{ name = "rwkv" },
|
|
|
|
| 923 |
{ name = "snowflake-id" },
|
| 924 |
]
|
| 925 |
|
|
@@ -940,11 +966,13 @@ requires-dist = [
|
|
| 940 |
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.11" },
|
| 941 |
{ name = "huggingface-hub", specifier = ">=0.29.1" },
|
| 942 |
{ name = "loguru", specifier = ">=0.7.3" },
|
|
|
|
| 943 |
{ name = "numpy", specifier = ">=2.2.3" },
|
| 944 |
{ name = "pydantic", specifier = ">=2.10.6" },
|
| 945 |
{ name = "pydantic-settings", specifier = ">=2.8.1" },
|
| 946 |
{ name = "pynvml", specifier = ">=12.0.0" },
|
| 947 |
{ name = "rwkv", specifier = "==0.8.28" },
|
|
|
|
| 948 |
{ name = "snowflake-id", specifier = ">=1.0.2" },
|
| 949 |
{ name = "torch", marker = "extra == 'cpu'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "rwkv-hf-space", extra = "cpu" } },
|
| 950 |
{ name = "torch", marker = "extra == 'cu113'", index = "https://download.pytorch.org/whl/cu113", conflict = { package = "rwkv-hf-space", extra = "cu113" } },
|
|
|
|
| 446 |
{ url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 },
|
| 447 |
]
|
| 448 |
|
| 449 |
+
[[package]]
|
| 450 |
+
name = "ninja"
|
| 451 |
+
version = "1.11.1.3"
|
| 452 |
+
source = { registry = "https://pypi.org/simple" }
|
| 453 |
+
sdist = { url = "https://files.pythonhosted.org/packages/bd/8f/21a2701f95b7d0d5137736561b3427ece0c4a1e085d4a223b92d16ab7d8b/ninja-1.11.1.3.tar.gz", hash = "sha256:edfa0d2e9d7ead1635b03e40a32ad56cc8f56798b6e2e9848d8300b174897076", size = 129532 }
|
| 454 |
+
wheels = [
|
| 455 |
+
{ url = "https://files.pythonhosted.org/packages/ea/ba/0069cd4a83d68f7b0308be70e219b15d675e50c8ea28763a3f0373c45bfc/ninja-1.11.1.3-py3-none-macosx_10_9_universal2.whl", hash = "sha256:2b4879ea3f1169f3d855182c57dcc84d1b5048628c8b7be0d702b81882a37237", size = 279132 },
|
| 456 |
+
{ url = "https://files.pythonhosted.org/packages/72/6b/3805be87df8417a0c7b21078c8045f2a1e59b34f371bfe4cb4fb0d6df7f2/ninja-1.11.1.3-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:bc3ebc8b2e47716149f3541742b5cd8e0b08f51013b825c05baca3e34854370d", size = 472101 },
|
| 457 |
+
{ url = "https://files.pythonhosted.org/packages/6b/35/a8e38d54768e67324e365e2a41162be298f51ec93e6bd4b18d237d7250d8/ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a27e78ca71316c8654965ee94b286a98c83877bfebe2607db96897bbfe458af0", size = 422884 },
|
| 458 |
+
{ url = "https://files.pythonhosted.org/packages/2f/99/7996457319e139c02697fb2aa28e42fe32bb0752cef492edc69d56a3552e/ninja-1.11.1.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2883ea46b3c5079074f56820f9989c6261fcc6fd873d914ee49010ecf283c3b2", size = 157046 },
|
| 459 |
+
{ url = "https://files.pythonhosted.org/packages/6d/8b/93f38e5cddf76ccfdab70946515b554f25d2b4c95ef9b2f9cfbc43fa7cc1/ninja-1.11.1.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c4bdb9fd2d0c06501ae15abfd23407660e95659e384acd36e013b6dd7d8a8e4", size = 180014 },
|
| 460 |
+
{ url = "https://files.pythonhosted.org/packages/7d/1d/713884d0fa3c972164f69d552e0701d30e2bf25eba9ef160bfb3dc69926a/ninja-1.11.1.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:114ed5c61c8474df6a69ab89097a20749b769e2c219a452cb2fadc49b0d581b0", size = 157098 },
|
| 461 |
+
{ url = "https://files.pythonhosted.org/packages/c7/22/ecb0f70e77c9e22ee250aa717a608a142756833a34d43943d7d658ee0e56/ninja-1.11.1.3-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7fa2247fce98f683bc712562d82b22b8a0a5c000738a13147ca2d1b68c122298", size = 130089 },
|
| 462 |
+
{ url = "https://files.pythonhosted.org/packages/ec/a6/3ee846c20ab6ad95b90c5c8703c76cb1f39cc8ce2d1ae468956e3b1b2581/ninja-1.11.1.3-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:a38c6c6c8032bed68b70c3b065d944c35e9f903342875d3a3218c1607987077c", size = 372508 },
|
| 463 |
+
{ url = "https://files.pythonhosted.org/packages/95/0d/aa44abe4141f29148ce671ac8c92045878906b18691c6f87a29711c2ff1c/ninja-1.11.1.3-py3-none-musllinux_1_1_i686.whl", hash = "sha256:56ada5d33b8741d298836644042faddebc83ee669782d661e21563034beb5aba", size = 419369 },
|
| 464 |
+
{ url = "https://files.pythonhosted.org/packages/f7/ec/48bf5105568ac9bd2016b701777bdd5000cc09a14ac837fef9f15e8d634e/ninja-1.11.1.3-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:53409151da081f3c198bb0bfc220a7f4e821e022c5b7d29719adda892ddb31bb", size = 420304 },
|
| 465 |
+
{ url = "https://files.pythonhosted.org/packages/18/e5/69df63976cf971a03379899f8520a036c9dbab26330b37197512aed5b3df/ninja-1.11.1.3-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:1ad2112c2b0159ed7c4ae3731595191b1546ba62316fc40808edecd0306fefa3", size = 416056 },
|
| 466 |
+
{ url = "https://files.pythonhosted.org/packages/6f/4f/bdb401af7ed0e24a3fef058e13a149f2de1ce4b176699076993615d55610/ninja-1.11.1.3-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:28aea3c1c280cba95b8608d50797169f3a34280e3e9a6379b6e340f0c9eaeeb0", size = 379725 },
|
| 467 |
+
{ url = "https://files.pythonhosted.org/packages/bd/68/05e7863bf13128c61652eeb3ec7096c3d3a602f32f31752dbfb034e3fa07/ninja-1.11.1.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b6966f83064a88a51693073eea3decd47e08c3965241e09578ef7aa3a7738329", size = 434881 },
|
| 468 |
+
{ url = "https://files.pythonhosted.org/packages/bd/ad/edc0d1efe77f29f45bbca2e1dab07ef597f61a88de6e4bccffc0aec2256c/ninja-1.11.1.3-py3-none-win32.whl", hash = "sha256:a4a3b71490557e18c010cbb26bd1ea9a0c32ee67e8f105e9731515b6e0af792e", size = 255988 },
|
| 469 |
+
{ url = "https://files.pythonhosted.org/packages/03/93/09a9f7672b4f97438aca6217ac54212a63273f1cd3b46b731d0bb22c53e7/ninja-1.11.1.3-py3-none-win_amd64.whl", hash = "sha256:04d48d14ea7ba11951c156599ab526bdda575450797ff57c6fdf99b2554d09c7", size = 296502 },
|
| 470 |
+
{ url = "https://files.pythonhosted.org/packages/d9/9d/0cc1e82849070ff3cbee69f326cb48a839407bcd15d8844443c30a5e7509/ninja-1.11.1.3-py3-none-win_arm64.whl", hash = "sha256:17978ad611d8ead578d83637f5ae80c2261b033db0b493a7ce94f88623f29e1b", size = 270571 },
|
| 471 |
+
]
|
| 472 |
+
|
| 473 |
[[package]]
|
| 474 |
name = "numpy"
|
| 475 |
version = "2.2.3"
|
|
|
|
| 939 |
{ name = "fastapi", extra = ["standard"] },
|
| 940 |
{ name = "huggingface-hub" },
|
| 941 |
{ name = "loguru" },
|
| 942 |
+
{ name = "ninja" },
|
| 943 |
{ name = "numpy" },
|
| 944 |
{ name = "pydantic" },
|
| 945 |
{ name = "pydantic-settings" },
|
| 946 |
{ name = "pynvml" },
|
| 947 |
{ name = "rwkv" },
|
| 948 |
+
{ name = "setuptools" },
|
| 949 |
{ name = "snowflake-id" },
|
| 950 |
]
|
| 951 |
|
|
|
|
| 966 |
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.11" },
|
| 967 |
{ name = "huggingface-hub", specifier = ">=0.29.1" },
|
| 968 |
{ name = "loguru", specifier = ">=0.7.3" },
|
| 969 |
+
{ name = "ninja", specifier = ">=1.11.1.3" },
|
| 970 |
{ name = "numpy", specifier = ">=2.2.3" },
|
| 971 |
{ name = "pydantic", specifier = ">=2.10.6" },
|
| 972 |
{ name = "pydantic-settings", specifier = ">=2.8.1" },
|
| 973 |
{ name = "pynvml", specifier = ">=12.0.0" },
|
| 974 |
{ name = "rwkv", specifier = "==0.8.28" },
|
| 975 |
+
{ name = "setuptools", specifier = ">=75.8.2" },
|
| 976 |
{ name = "snowflake-id", specifier = ">=1.0.2" },
|
| 977 |
{ name = "torch", marker = "extra == 'cpu'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "rwkv-hf-space", extra = "cpu" } },
|
| 978 |
{ name = "torch", marker = "extra == 'cu113'", index = "https://download.pytorch.org/whl/cu113", conflict = { package = "rwkv-hf-space", extra = "cu113" } },
|