drbh committed
Commit 79aac9d · 0 Parent(s)

feat: impl residual rms kernel repo

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ py_example
README.md ADDED
@@ -0,0 +1,31 @@
+ ---
+ license: apache-2.0
+ ---
+
+ ## Residual RMS for ROCm
+
+ Residual RMS kernels from [residual_rms](https://github.com/huggingface/hf-rocm-kernels).
+
+ ## Development
+
+ This kernel can be built with the [HF Kernel Builder](https://github.com/huggingface/kernel-builder) using the following commands.
+
+ ### Build
+
+ ```bash
+ nix build .#bundle -L
+ ```
+
+ ### Dev shell
+
+ ```bash
+ nix develop -L
+ pytest tests
+ ```
+
+ ### Publish
+
+ ```bash
+ git remote add origin [email protected]:kernels-community/residual_rms
+ git push -u origin main
+ ```
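
For orientation, the fused operation these kernels implement is a residual add followed by RMS normalization, a weight multiply, scaling, and FP8 conversion. Below is a minimal PyTorch sketch of that computation; it is illustrative only (the function name, the torch fp8 dtype, and the exact rounding/clamping behaviour are assumptions, not this repo's API):

```python
import torch


def residual_rms_reference(
    input: torch.Tensor,     # [m, n], fp16
    residual: torch.Tensor,  # [m, n], fp16 (the kernel also writes input + residual back into this tensor)
    weight: torch.Tensor,    # [n], fp16
    epsilon: float,
    scale: float,
) -> torch.Tensor:
    # Residual connection, accumulated in fp32 like the kernel's variance accumulator.
    hidden = input.float() + residual.float()
    # RMS normalization over the last dimension.
    variance = hidden.pow(2).mean(dim=-1, keepdim=True)
    normed = hidden * torch.rsqrt(variance + epsilon)
    # Weight multiply in fp16, then scale and quantize to FP8.
    # The fp8 dtype name is an assumption; the kernel clamps and converts to E4M3 FNUZ.
    out = (normed.half() * weight).float() * scale
    return out.to(torch.float8_e4m3fnuz)
```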
build.toml ADDED
@@ -0,0 +1,21 @@
+ [general]
+ version = "0.0.1"
+
+ [torch]
+ name = "residual_rms"
+ src = [
+   "ext-torch/registration.h",
+   "ext-torch/torch_binding.cpp",
+   "ext-torch/torch_binding.h",
+ ]
+ include = ["."]
+ pyroot = "ext-torch"
+ pyext = ["py", "json"]
+
+ [kernel.residual_rms]
+ capabilities = ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"]
+ src = [
+   "residual_rms/residual_rms_dispatch.cu",
+   "residual_rms/compat.h",
+ ]
+ depends = ["torch"]
ext-torch/registration.h ADDED
@@ -0,0 +1,27 @@
+ #pragma once
+
+ #include <Python.h>
+
+ #define _CONCAT(A, B) A##B
+ #define CONCAT(A, B) _CONCAT(A, B)
+
+ #define _STRINGIFY(A) #A
+ #define STRINGIFY(A) _STRINGIFY(A)
+
+ // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+ // could be a macro instead of a literal token.
+ #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
+ // could be a macro instead of a literal token.
+ #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
+   TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
+
+ // REGISTER_EXTENSION allows the shared library to be loaded and initialized
+ // via python's import statement.
+ #define REGISTER_EXTENSION(NAME)                                        \
+   PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                              \
+     static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,          \
+                                         STRINGIFY(NAME), nullptr, 0, nullptr}; \
+     return PyModule_Create(&module);                                    \
+   }
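
As the comment above notes, `REGISTER_EXTENSION` lets the compiled shared library be imported directly from Python, after which the ops registered via `TORCH_LIBRARY` are reachable under `torch.ops`. A hedged sketch of that flow (the module name `_residual_rms` is taken from this repo's fallback import path and may differ from the built artifact):

```python
import torch

# Importing the built extension runs REGISTER_EXTENSION and the TORCH_LIBRARY
# registration in torch_binding.cpp.
import _residual_rms  # noqa: F401

ops = torch.ops._residual_rms
print(ops.residual_rms)  # exposes the schema defined in torch_binding.cpp
```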
ext-torch/residual_rms/__init__.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+
+ try:
+     from ._ops import ops
+ except ImportError as e:
+     # Fallback for local development.
+     try:
+         import _residual_rms
+
+         ops = torch.ops._residual_rms
+     except ImportError:
+         raise e
+
+
+ def residual_rms(
+     input: torch.Tensor,
+     residual: torch.Tensor,
+     weight: torch.Tensor,
+     output: torch.Tensor,
+     epsilon: float,
+     scale: float,
+     mode: int,
+     num_threads: int,
+ ) -> torch.Tensor:
+     # Thin wrapper matching the op schema registered in torch_binding.cpp.
+     ops.residual_rms(input, residual, weight, output, epsilon, scale, mode, num_threads)
+     return output
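
A hedged usage sketch for the wrapper above, with shapes and dtypes taken from the comments in `residual_rms_dispatch.cu`; the fp8 output dtype, the epsilon/scale values, and the thread count are assumptions:

```python
import torch

from residual_rms import residual_rms

m, n = 8, 4096
input = torch.randn(m, n, dtype=torch.float16, device="cuda")
residual = torch.randn(m, n, dtype=torch.float16, device="cuda")
weight = torch.randn(n, dtype=torch.float16, device="cuda")
# FP8 output buffer; the exact torch fp8 dtype is an assumption (the kernel emits E4M3 FNUZ bytes).
output = torch.empty(m, n, dtype=torch.float8_e4m3fnuz, device="cuda")

# `mode` selects the kernel version (0-4) and `num_threads` the block size.
residual_rms(input, residual, weight, output, epsilon=1e-6, scale=1.0, mode=4, num_threads=256)
```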
ext-torch/torch_binding.cpp ADDED
@@ -0,0 +1,12 @@
+ #include <torch/library.h>
+
+ #include "registration.h"
+ #include "torch_binding.h"
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+   // Compute the residual root mean square.
+   ops.def("residual_rms(Tensor input, Tensor residual, Tensor weight, Tensor output, float epsilon, float scale, int mode, int num_threads) -> ()");
+   ops.impl("residual_rms", torch::kCUDA, &residual_rms);
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
ext-torch/torch_binding.h ADDED
@@ -0,0 +1,10 @@
+ #pragma once
+
+ #include <optional>
+ #include <torch/library.h>
+
+ #include <vector>
+
+ void residual_rms(torch::Tensor &input, torch::Tensor &residual,
+                   torch::Tensor &weight, torch::Tensor &output, double epsilon,
+                   double scale, int64_t mode, int64_t num_threads);
flake.lock ADDED
@@ -0,0 +1,95 @@
+ {
+   "nodes": {
+     "flake-compat": {
+       "locked": {
+         "lastModified": 1733328505,
+         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+         "type": "github"
+       },
+       "original": {
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "type": "github"
+       }
+     },
+     "flake-utils": {
+       "inputs": {
+         "systems": "systems"
+       },
+       "locked": {
+         "lastModified": 1731533236,
+         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "kernel-builder": {
+       "inputs": {
+         "flake-compat": "flake-compat",
+         "flake-utils": "flake-utils",
+         "nixpkgs": "nixpkgs"
+       },
+       "locked": {
+         "lastModified": 1738315861,
+         "narHash": "sha256-QPWRaIPAMmQANuAOaZIKzh1e69OG8zBWGg+swESEajw=",
+         "ref": "refs/heads/main",
+         "rev": "eabeadcedba5dcef2a562b8f1ed5ec1feb485496",
+         "revCount": 72,
+         "type": "git",
+         "url": "ssh://[email protected]/huggingface/kernel-builder"
+       },
+       "original": {
+         "type": "git",
+         "url": "ssh://[email protected]/huggingface/kernel-builder"
+       }
+     },
+     "nixpkgs": {
+       "locked": {
+         "lastModified": 1738247409,
+         "narHash": "sha256-F72dKl9Na6/2N+garOm9qCXPa92GzR8eYSuDra6kbjY=",
+         "owner": "danieldk",
+         "repo": "nixpkgs",
+         "rev": "358f57074b70e3ee9e1dc118151a4f6f81fcd3bb",
+         "type": "github"
+       },
+       "original": {
+         "owner": "danieldk",
+         "ref": "cuda-12.6-for-kernel-builder",
+         "repo": "nixpkgs",
+         "type": "github"
+       }
+     },
+     "root": {
+       "inputs": {
+         "kernel-builder": "kernel-builder"
+       }
+     },
+     "systems": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     }
+   },
+   "root": "root",
+   "version": 7
+ }
flake.nix ADDED
@@ -0,0 +1,14 @@
+ {
+   description = "Flake for rocm residual rms kernels";
+
+   inputs = {
+     kernel-builder.url = "git+ssh://[email protected]/huggingface/kernel-builder";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs ./.;
+ }
residual_rms/compat.h ADDED
@@ -0,0 +1,5 @@
+ #pragma once
+
+ #include <hip/hip_runtime.h>
+
+ #define WARP_SIZE 32
residual_rms/residual_rms_dispatch.cu ADDED
@@ -0,0 +1,56 @@
+ #include <ATen/cuda/CUDAContext.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include <hip/hip_runtime.h>
+
+ #include "residual_rms/residual_rms_v0.cu"
+ #include "residual_rms/residual_rms_v1.cu"
+ #include "residual_rms/residual_rms_v2.cu"
+ #include "residual_rms/residual_rms_v3.cu"
+ #include "residual_rms/residual_rms_v4.cu"
+
+ void residual_rms(torch::Tensor& input,     // Shape: [m, n] / Layout: row-major / Dtype: fp16
+                   torch::Tensor& residual,  // Shape: [m, n] / Layout: row-major / Dtype: fp16
+                   torch::Tensor& weight,    // Shape: [n]    / Layout: row-major / Dtype: fp16
+                   torch::Tensor& output,    // Shape: [m, n] / Layout: row-major / Dtype: fp8
+                   double epsilon, double scale, int64_t mode,
+                   int64_t num_threads) {  // TODO: add fp16 output mode
+
+   // Retrieve shapes
+   const int rows = input.size(0);
+   const int cols = input.size(1);
+   // Activate device guard
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+   // Prepare kernel launch arguments
+   dim3 grid(rows);
+   dim3 block(num_threads);
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+   // Launch kernel
+   switch (mode) {
+     case 1:
+       LAUNCH_RESIDUAL_RMS_V1;
+       break;
+     case 2:
+       LAUNCH_RESIDUAL_RMS_V2;
+       break;
+     case 3:
+       LAUNCH_RESIDUAL_RMS_V3;
+       break;
+     case 4:
+       LAUNCH_RESIDUAL_RMS_V4;
+       break;
+     default:
+       LAUNCH_RESIDUAL_RMS_V0;
+       break;
+   }
+ }
+
+ /*
+ Versions:
+ 0. non-vectorized version
+ 1. vectorizes loads and stores
+ 2. simplified indexing
+ 3. added packed conversion
+ 4. using packed types everywhere and custom ASM for residual connection and variance
+ */
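
The `mode` argument of the dispatcher selects one of the kernel versions listed above. A rough timing sketch for comparing them (illustrative only; the shapes, fp8 dtype name, warmup and repeat counts are arbitrary assumptions):

```python
import torch

from residual_rms import residual_rms

m, n = 256, 8192
input = torch.randn(m, n, dtype=torch.float16, device="cuda")
residual = torch.randn(m, n, dtype=torch.float16, device="cuda")
weight = torch.randn(n, dtype=torch.float16, device="cuda")
output = torch.empty(m, n, dtype=torch.float8_e4m3fnuz, device="cuda")

for mode in range(5):
    # Warmup, then time with CUDA events (torch exposes the same event API on ROCm/HIP).
    for _ in range(10):
        residual_rms(input, residual, weight, output, 1e-6, 1.0, mode, 256)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        residual_rms(input, residual, weight, output, 1e-6, 1.0, mode, 256)
    end.record()
    torch.cuda.synchronize()
    print(f"mode {mode}: {start.elapsed_time(end) / 100:.3f} ms")
```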
residual_rms/residual_rms_v0.cu ADDED
@@ -0,0 +1,57 @@
+ #include <torch/all.h>
+
+ #include <hip/hip_bf16.h>
+ #include <hip/hip_fp16.h>
+ #include <hipcub/util_type.hpp>
+ #include <hipcub/hipcub.hpp>
+ #include <hip/hip_fp8.h>
+
+ #include "utils/macros.h"
+
+ __global__ void _residual_rms_v0(const half* __restrict__ input, half* __restrict__ residual,
+                                  const half* __restrict__ weight, __hip_fp8_storage_t* __restrict__ output,
+                                  const float epsilon, const float scale, const int cols) {
+   // Advance pointers according to the position of the thread in the grid
+   input += blockIdx.x * cols;
+   residual += blockIdx.x * cols;
+   output += blockIdx.x * cols;
+
+   // Residual connection: inplace add of input to residual, accumulate norm along the way
+   float variance = 0.0f;
+
+   for (int i = threadIdx.x; i < cols; i += blockDim.x) {
+     half z = input[i];
+     z += residual[i];
+     float x = (float)z;
+     variance += (x * x);
+     residual[i] = z;
+   }
+   variance /= cols;
+
+   // Block reduce to compute the total norm
+   __shared__ float shared_normalizer;
+   using BlockReduce = hipcub::BlockReduce<float, 1024>;
+   __shared__ typename BlockReduce::TempStorage reduceStore;
+
+   variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+   if (threadIdx.x == 0) {
+     shared_normalizer = rsqrtf(variance + epsilon);
+   }
+   __syncthreads();
+
+   // Normalize and convert
+   for (int idx = threadIdx.x; idx < cols; idx += blockDim.x) {
+     float x = (float)residual[idx];
+     half y = (half)(x * shared_normalizer);
+     y = (y * weight[idx]);
+     x = (float)y;
+     x *= scale;
+     FP8_CLAMP(x, float);
+     output[idx] = __hip_cvt_float_to_fp8(x, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+   }
+ }
+
+ #define LAUNCH_RESIDUAL_RMS_V0 \
+   (_residual_rms_v0<<<grid, block, 0, stream>>>((half*)input.data_ptr(), (half*)residual.data_ptr(), \
+                                                 (half*)weight.data_ptr(), (__hip_fp8_storage_t*)output.data_ptr(), \
+                                                 epsilon, scale, cols))
residual_rms/residual_rms_v1.cu ADDED
@@ -0,0 +1,103 @@
+ #include <torch/all.h>
+
+ #include <hip/hip_bf16.h>
+ #include <hip/hip_fp16.h>
+ #include <hipcub/util_type.hpp>
+ #include <hipcub/hipcub.hpp>
+ #include <hip/hip_fp8.h>
+
+ #include "utils/macros.h"
+
+ #define WPT 8 // WorkPerThreads
+
+ __global__ void _residual_rms_v1(const half* __restrict__ input, half* __restrict__ residual,
+                                  const half* __restrict__ weight, __hip_fp8_storage_t* __restrict__ output,
+                                  const float epsilon, const float scale, const int cols) {
+   // Advance pointers according to the position of the thread in the grid
+   input += blockIdx.x * cols;
+   residual += blockIdx.x * cols;
+   output += blockIdx.x * cols;
+
+   // Residual connection: inplace add of input to residual, accumulate norm along the way
+   float variance = 0.0f;
+   float fp32_residual;
+   half input_buffer[WPT];
+   half residual_buffer[WPT];
+
+   for (int i = WPT * threadIdx.x; i < cols; i += WPT * blockDim.x) {
+     // Load data using 128-bits loads
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       input_buffer[j] = input[i + j];
+     }
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual_buffer[j] = residual[i + j];
+     }
+
+     // Add everything in the residual buffer and accumulate variance
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual_buffer[j] += input_buffer[j];
+       fp32_residual = (float)residual_buffer[j];
+       variance += fp32_residual * fp32_residual;
+     }
+
+     // 128-bits store
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual[i + j] = residual_buffer[j];
+     }
+   }
+   variance /= cols;
+
+   // Block reduce to compute the total norm
+   __shared__ float shared_normalizer;
+   using BlockReduce = hipcub::BlockReduce<float, 1024>;
+   __shared__ typename BlockReduce::TempStorage reduceStore;
+
+   variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+   if (threadIdx.x == 0) {
+     shared_normalizer = rsqrtf(variance + epsilon);
+   }
+   __syncthreads();
+
+   // Normalize and convert
+   float tmp_float;
+   half residual_buffer_[WPT];
+   half weight_buffer[WPT];
+   __hip_fp8_storage_t fp8_buffer[WPT];
+
+   for (int i = WPT * threadIdx.x; i < cols; i += WPT * blockDim.x) {
+     // 128-bits loads
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual_buffer_[j] = residual[i + j];
+     }
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       weight_buffer[j] = weight[i + j];
+     }
+
+     // Compute and fill buffer
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       tmp_float = (float)residual_buffer_[j] * shared_normalizer;
+       tmp_float = (float)((half)(tmp_float)*weight_buffer[j]);
+       tmp_float *= scale;
+       FP8_CLAMP(tmp_float, float);
+       fp8_buffer[j] = __hip_cvt_float_to_fp8(tmp_float, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+     }
+
+     // 64b store
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       output[i + j] = fp8_buffer[j];
+     }
+   }
+ }
+
+ #define LAUNCH_RESIDUAL_RMS_V1 \
+   (_residual_rms_v1<<<grid, block, 0, stream>>>((half*)input.data_ptr(), (half*)residual.data_ptr(), \
+                                                 (half*)weight.data_ptr(), (__hip_fp8_storage_t*)output.data_ptr(), \
+                                                 epsilon, scale, cols))
residual_rms/residual_rms_v2.cu ADDED
@@ -0,0 +1,118 @@
+ #include <torch/all.h>
+
+ #include <hip/hip_bf16.h>
+ #include <hip/hip_fp16.h>
+ #include <hipcub/util_type.hpp>
+ #include <hipcub/hipcub.hpp>
+ #include <hip/hip_fp8.h>
+
+ #include "utils/macros.h"
+
+ #define WPT 8 // WorkPerThreads
+ #define CDIV(a, b) ((a + b - 1) / (b)) // Ceiling division
+
+ __global__ void _residual_rms_v2(const half* __restrict__ input, half* __restrict__ residual,
+                                  const half* __restrict__ weight, __hip_fp8_storage_t* __restrict__ output,
+                                  const float epsilon, const float scale, const int cols) {
+   // Advance pointers according to the position of the thread in the grid
+   input += blockIdx.x * cols + WPT * threadIdx.x;
+   residual += blockIdx.x * cols + WPT * threadIdx.x;
+   weight += WPT * threadIdx.x;
+   output += blockIdx.x * cols + WPT * threadIdx.x;
+   half* residual_start = residual;
+
+   // Residual connection: inplace add of input to residual, accumulate norm along the way
+   float variance = 0.0f;
+   float fp32_residual;
+   half input_buffer[WPT];
+   half residual_buffer[WPT];
+
+   const int loop_stride = WPT * blockDim.x;
+   const int iterations = CDIV(cols - WPT * threadIdx.x, loop_stride);
+   for (int i = 0; i < iterations; i++) {
+     // Load data using 128-bits loads
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       input_buffer[j] = input[j];
+     }
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual_buffer[j] = residual[j];
+     }
+
+     // Add everything in the residual buffer and accumulate variance
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual_buffer[j] += input_buffer[j];
+       fp32_residual = (float)residual_buffer[j];
+       variance += fp32_residual * fp32_residual;
+     }
+
+     // 128-bits store
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual[j] = residual_buffer[j];
+     }
+
+     // Advance pointers
+     input += loop_stride;
+     residual += loop_stride;
+   }
+   variance /= cols;
+
+   // Block reduce to compute the total norm
+   __shared__ float shared_normalizer;
+   using BlockReduce = hipcub::BlockReduce<float, 1024>;
+   __shared__ typename BlockReduce::TempStorage reduceStore;
+
+   variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+   if (threadIdx.x == 0) {
+     shared_normalizer = rsqrtf(variance + epsilon);
+   }
+   __syncthreads();
+
+   // Normalize and convert
+   float tmp_float;
+   half residual_buffer_[WPT];
+   half weight_buffer[WPT];
+   __hip_fp8_storage_t fp8_buffer[WPT];
+
+   residual = residual_start;
+   for (int i = 0; i < iterations; i++) {
+     // 128-bits loads
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual_buffer_[j] = residual[j];
+     }
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       weight_buffer[j] = weight[j];
+     }
+
+     // Compute and fill buffer
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       tmp_float = (float)residual_buffer_[j] * shared_normalizer;
+       tmp_float = (float)((half)(tmp_float)*weight_buffer[j]);
+       tmp_float *= scale;
+       FP8_CLAMP(tmp_float, float);
+       fp8_buffer[j] = __hip_cvt_float_to_fp8(tmp_float, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+     }
+
+     // 64b store
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       output[j] = fp8_buffer[j];
+     }
+
+     // Advance pointers
+     residual += loop_stride;
+     weight += loop_stride;
+     output += loop_stride;
+   }
+ }
+
+ #define LAUNCH_RESIDUAL_RMS_V2 \
+   (_residual_rms_v2<<<grid, block, 0, stream>>>((half*)input.data_ptr(), (half*)residual.data_ptr(), \
+                                                 (half*)weight.data_ptr(), (__hip_fp8_storage_t*)output.data_ptr(), \
+                                                 epsilon, scale, cols))
residual_rms/residual_rms_v3.cu ADDED
@@ -0,0 +1,125 @@
+ #include <torch/all.h>
+
+ #include <hip/hip_bf16.h>
+ #include <hip/hip_fp16.h>
+ #include <hipcub/util_type.hpp>
+ #include <hipcub/hipcub.hpp>
+ #include <hip/hip_fp8.h>
+
+ #include "utils/macros.h"
+
+ #define WPT 8 // WorkPerThreads
+ #define CDIV(a, b) ((a + b - 1) / (b)) // Ceiling division
+
+ __global__ void _residual_rms_v3(const half* __restrict__ input, half* __restrict__ residual,
+                                  const half* __restrict__ weight, __hip_fp8x2_storage_t* __restrict__ output,
+                                  const float epsilon, const float scale, const int cols) {
+   // Advance pointers according to the position of the thread in the grid
+   input += blockIdx.x * cols + WPT * threadIdx.x;
+   residual += blockIdx.x * cols + WPT * threadIdx.x;
+   weight += WPT * threadIdx.x;
+   output += (blockIdx.x * cols + WPT * threadIdx.x) / 2;
+   half* residual_start = residual;
+
+   // Residual connection: inplace add of input to residual, accumulate norm along the way
+   float variance = 0.0f;
+   float fp32_residual;
+   half input_buffer[WPT];
+   half residual_buffer[WPT];
+
+   const int loop_stride = WPT * blockDim.x;
+   const int iterations = CDIV(cols - WPT * threadIdx.x, loop_stride);
+   for (int i = 0; i < iterations; i++) {
+     // Load data using 128-bits loads
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       input_buffer[j] = input[j];
+     }
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual_buffer[j] = residual[j];
+     }
+
+     // Add everything in the residual buffer and accumulate variance
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual_buffer[j] += input_buffer[j];
+       fp32_residual = (float)residual_buffer[j];
+       variance += fp32_residual * fp32_residual;
+     }
+
+     // 128-bits store
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual[j] = residual_buffer[j];
+     }
+
+     // Advance pointers
+     input += loop_stride;
+     residual += loop_stride;
+   }
+   variance /= cols;
+
+   // Block reduce to compute the total norm
+   __shared__ float shared_normalizer;
+   using BlockReduce = hipcub::BlockReduce<float, 1024>;
+   __shared__ typename BlockReduce::TempStorage reduceStore;
+
+   variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+   if (threadIdx.x == 0) {
+     shared_normalizer = rsqrtf(variance + epsilon);
+   }
+   __syncthreads();
+
+   // Normalize and convert
+   float2 tmp_float2;
+   half residual_buffer_[WPT];
+   half weight_buffer[WPT];
+   __hip_fp8x2_storage_t fp8x2_buffer[WPT / 2];
+
+   residual = residual_start;
+   for (int i = 0; i < iterations; i++) {
+     // 128-bits loads
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       residual_buffer_[j] = residual[j];
+     }
+     #pragma unroll
+     for (int j = 0; j < WPT; j++) {
+       weight_buffer[j] = weight[j];
+     }
+
+     // Compute and fill buffer
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       // .x
+       tmp_float2.x = (float)residual_buffer_[2 * j] * shared_normalizer;
+       tmp_float2.x = (float)((half)(tmp_float2.x) * weight_buffer[2 * j]);
+       tmp_float2.x *= scale;
+       FP8_CLAMP(tmp_float2.x, float);
+       // .y
+       tmp_float2.y = (float)residual_buffer_[2 * j + 1] * shared_normalizer;
+       tmp_float2.y = (float)((half)(tmp_float2.y) * weight_buffer[2 * j + 1]);
+       tmp_float2.y *= scale;
+       FP8_CLAMP(tmp_float2.y, float);
+       // convert
+       fp8x2_buffer[j] = __hip_cvt_float2_to_fp8x2(tmp_float2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+     }
+
+     // 64b store
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       output[j] = fp8x2_buffer[j];
+     }
+
+     // Advance pointers
+     residual += loop_stride;
+     weight += loop_stride;
+     output += loop_stride / 2;
+   }
+ }
+
+ #define LAUNCH_RESIDUAL_RMS_V3 \
+   (_residual_rms_v3<<<grid, block, 0, stream>>>((half*)input.data_ptr(), (half*)residual.data_ptr(), \
+                                                 (half*)weight.data_ptr(), (__hip_fp8x2_storage_t*)output.data_ptr(), \
+                                                 epsilon, scale, cols))
residual_rms/residual_rms_v4.cu ADDED
@@ -0,0 +1,126 @@
+ #include <torch/all.h>
+
+ #include <hip/hip_bf16.h>
+ #include <hip/hip_fp16.h>
+ #include <hipcub/util_type.hpp>
+ #include <hipcub/hipcub.hpp>
+ #include <hip/hip_fp8.h>
+
+ #include "utils/macros.h"
+
+ #define WPT 8 // WorkPerThreads
+ #define CDIV(a, b) ((a + b - 1) / (b)) // Ceiling division
+
+ __global__ void _residual_rms_v4(const __half2* __restrict__ input, __half2* __restrict__ residual,
+                                  const __half2* __restrict__ weight, __hip_fp8x2_storage_t* __restrict__ output,
+                                  const float epsilon, const float scale, const int cols) {
+   // Advance pointers according to the position of the thread in the grid
+   input += (blockIdx.x * cols + WPT * threadIdx.x) / 2;
+   residual += (blockIdx.x * cols + WPT * threadIdx.x) / 2;
+   weight += (WPT * threadIdx.x) / 2;
+   output += (blockIdx.x * cols + WPT * threadIdx.x) / 2;
+
+   // Residual connection: inplace add of input to residual, accumulate norm along the way
+   float variance = 0.0f;
+   float fp32_residual;
+   __half2 input_buffer[WPT / 2];
+   __half2 residual_buffer[WPT / 2];
+
+   const int loop_stride = blockDim.x * (WPT / 2);
+   const int iterations = CDIV(cols - WPT * threadIdx.x, 2 * loop_stride);
+   for (int i = 0; i < iterations; i++) {
+     // Load data using 128-bits loads
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       input_buffer[j] = input[j];
+     }
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       residual_buffer[j] = residual[j];
+     }
+
+     // Residual connection and variance accumulation
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       asm volatile(
+           "V_PK_ADD_F16 %0, %2, %3\n\t"
+           "V_DOT2C_F32_F16 %1, %2, %2"
+           : "=v"(residual_buffer[j]), "=v"(variance)
+           : "0"(residual_buffer[j]), "v"(input_buffer[j]));
+     }
+
+     // 128-bits store
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       residual[j] = residual_buffer[j];
+     }
+
+     // Advance pointers
+     input += loop_stride;
+     residual += loop_stride;
+   }
+   variance /= cols;
+
+   // Block reduce to compute the total norm
+   __shared__ float shared_normalizer;
+   using BlockReduce = hipcub::BlockReduce<float, 1024>;
+   __shared__ typename BlockReduce::TempStorage reduceStore;
+
+   variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+   if (threadIdx.x == 0) {
+     shared_normalizer = rsqrtf(variance + epsilon);
+   }
+   __syncthreads();
+
+   // Normalize and convert
+   float2 tmp_float2;
+   __half2 residual_buffer_[WPT / 2];
+   __half2 weight_buffer[WPT / 2];
+   __hip_fp8x2_storage_t fp8x2_buffer[WPT / 2];
+
+   residual -= iterations * loop_stride;
+   for (int i = 0; i < iterations; i++) {
+     // 128-bits loads
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       residual_buffer_[j] = residual[j];
+     }
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       weight_buffer[j] = weight[j];
+     }
+
+     // Compute and fill buffer
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       // .x
+       tmp_float2.x = (float)residual_buffer_[j].x * shared_normalizer;
+       tmp_float2.x = (float)((half)(tmp_float2.x) * weight_buffer[j].x);
+       tmp_float2.x *= scale;
+       FP8_CLAMP(tmp_float2.x, float);
+       // .y
+       tmp_float2.y = (float)residual_buffer_[j].y * shared_normalizer;
+       tmp_float2.y = (float)((half)(tmp_float2.y) * weight_buffer[j].y);
+       tmp_float2.y *= scale;
+       FP8_CLAMP(tmp_float2.y, float);
+       // convert
+       fp8x2_buffer[j] = __hip_cvt_float2_to_fp8x2(tmp_float2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+     }
+
+     // 64b store
+     #pragma unroll
+     for (int j = 0; j < WPT / 2; j++) {
+       output[j] = fp8x2_buffer[j];
+     }
+
+     // Advance pointers
+     residual += loop_stride;
+     weight += loop_stride;
+     output += loop_stride;
+   }
+ }
+
+ #define LAUNCH_RESIDUAL_RMS_V4 \
+   (_residual_rms_v4<<<grid, block, 0, stream>>>((__half2*)input.data_ptr(), (__half2*)residual.data_ptr(), \
+                                                 (__half2*)weight.data_ptr(), \
+                                                 (__hip_fp8x2_storage_t*)output.data_ptr(), epsilon, scale, cols))
test/__init__.py ADDED
File without changes
test/kernels/__init__.py ADDED
File without changes
test/kernels/test_residual_rms.py ADDED
@@ -0,0 +1,35 @@
+ """Tests for the `residual_rms` kernel.
+
+ Run `pytest test/kernels/test_residual_rms.py`.
+ """
+
+ import pytest
+ import torch
+
+ from residual_rms import residual_rms
+
+
+ def reference(input, residual, weight, epsilon, scale):
+     # PyTorch reference: residual add, RMS norm over the last dim, weight multiply, scale.
+     hidden = input.float() + residual.float()
+     variance = hidden.pow(2).mean(dim=-1, keepdim=True)
+     return (hidden * torch.rsqrt(variance + epsilon)).half() * weight * scale
+
+
+ @pytest.mark.parametrize("shape", [(16, 1024), (7, 4096)])
+ @pytest.mark.parametrize("mode", [0, 1, 2, 3, 4])
+ def test_residual_rms(shape, mode) -> None:
+     torch.manual_seed(0)
+     rows, cols = shape
+     input = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
+     residual = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
+     weight = torch.randn(cols, dtype=torch.float16, device="cuda")
+     # FP8 output buffer (the torch dtype choice is an assumption; the kernel emits E4M3 FNUZ bytes).
+     output = torch.empty(rows, cols, dtype=torch.float8_e4m3fnuz, device="cuda")
+     epsilon, scale = 1e-6, 1.0
+
+     expected = reference(input, residual, weight, epsilon, scale)
+     residual_rms(input, residual, weight, output, epsilon, scale, mode, 256)
+
+     # FP8 quantization is lossy, so compare with a loose tolerance.
+     torch.testing.assert_close(output.float(), expected.float(), atol=0.25, rtol=0.25)