drbh committed · commit 79aac9d
feat: impl residual rms kernel repo
Files changed:
- .gitattributes +36 -0
- .gitignore +1 -0
- README.md +31 -0
- build.toml +21 -0
- ext-torch/registration.h +27 -0
- ext-torch/residual_rms/__init__.py +16 -0
- ext-torch/torch_binding.cpp +16 -0
- ext-torch/torch_binding.h +12 -0
- flake.lock +95 -0
- flake.nix +14 -0
- residual_rms/compat.h +5 -0
- residual_rms/residual_rms_dispatch.cu +56 -0
- residual_rms/residual_rms_v0.cu +57 -0
- residual_rms/residual_rms_v1.cu +103 -0
- residual_rms/residual_rms_v2.cu +118 -0
- residual_rms/residual_rms_v3.cu +125 -0
- residual_rms/residual_rms_v4.cu +126 -0
- test/__init__.py +0 -0
- test/kernels/__init__.py +0 -0
- test/kernels/test_residual_rms.py +20 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+py_example
README.md
ADDED
@@ -0,0 +1,31 @@
+---
+license: apache-2.0
+---
+
+## Residual RMS for ROCm
+
+Residual RMS kernels from [residual_rms](https://github.com/huggingface/hf-rocm-kernels).
+
+# Development
+
+This kernel can be built with the [HF Kernel Builder](https://github.com/huggingface/kernel-builder) using the following commands.
+
+## Build
+
+```bash
+nix build .#bundle -L
+```
+
+### Dev shell
+
+```bash
+nix develop -L
+pytest tests
+```
+
+## Publish
+
+```bash
+git remote add origin [email protected]:kernels-community/residual_rms
+git push -u origin main
+```
build.toml
ADDED
@@ -0,0 +1,21 @@
+[general]
+version = "0.0.1"
+
+[torch]
+name = "residual_rms"
+src = [
+  "ext-torch/registration.h",
+  "ext-torch/torch_binding.cpp",
+  "ext-torch/torch_binding.h",
+]
+include = ["."]
+pyroot = "ext-torch"
+pyext = ["py", "json"]
+
+[kernel.residual_rms]
+capabilities = ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"]
+src = [
+  "residual_rms/residual_rms_dispatch.cu",
+  "residual_rms/compat.h",
+]
+depends = ["torch"]
ext-torch/registration.h
ADDED
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <Python.h>
+
+#define _CONCAT(A, B) A##B
+#define CONCAT(A, B) _CONCAT(A, B)
+
+#define _STRINGIFY(A) #A
+#define STRINGIFY(A) _STRINGIFY(A)
+
+// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
+  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
+
+// REGISTER_EXTENSION allows the shared library to be loaded and initialized
+// via python's import statement.
+#define REGISTER_EXTENSION(NAME)                                               \
+  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
+    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
+                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
+    return PyModule_Create(&module);                                           \
+  }
ext-torch/residual_rms/__init__.py
ADDED
@@ -0,0 +1,16 @@
+import torch
+
+try:
+    from ._ops import ops
+except ImportError as e:
+    # Fallback for local development.
+    try:
+        import _residual_rms
+
+        ops = torch.ops._residual_rms
+    except ImportError:
+        raise e
+
+def residual_rms(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+    ops.residual_rms(out, x)
+    return out
ext-torch/torch_binding.cpp
ADDED
@@ -0,0 +1,16 @@
+#include <torch/library.h>
+
+#include "registration.h"
+#include "torch_binding.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  // Increment a tensor by 1.
+  ops.def("increment(Tensor x) -> ()");
+  ops.impl("increment", torch::kCUDA, &increment);
+
+  // Compute the residual root mean square.
+  ops.def("residual_rms(Tensor input, Tensor residual, Tensor weight, Tensor output, float epsilon, float scale, int mode, int num_threads) -> ()");
+  ops.impl("residual_rms", torch::kCUDA, &residual_rms);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
ext-torch/torch_binding.h
ADDED
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <optional>
+#include <torch/library.h>
+
+#include <vector>
+
+void increment(torch::Tensor &x);
+
+void residual_rms(torch::Tensor &input, torch::Tensor &residual,
+                  torch::Tensor &weight, torch::Tensor &output, double epsilon,
+                  double scale, int64_t mode, int64_t num_threads);
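Editorial note: once the extension is built, the `residual_rms` op registered above can be driven from Python through the `ops` table exposed by `ext-torch/residual_rms/__init__.py`. A minimal sketch follows; the argument values, the `torch.float8_e4m3fnuz` output dtype, and the availability of a ROCm device are assumptions for illustration, not guaranteed by this commit.

```python
import torch
import residual_rms  # the built package; re-exports the `ops` table from _ops

m, n = 8, 4096
device = "cuda"  # ROCm GPUs surface through the CUDA device type in PyTorch

input = torch.randn(m, n, dtype=torch.float16, device=device)
residual = torch.randn(m, n, dtype=torch.float16, device=device)
weight = torch.randn(n, dtype=torch.float16, device=device)
output = torch.empty(m, n, dtype=torch.float8_e4m3fnuz, device=device)

# epsilon, scale, mode (kernel version 0-4), num_threads (threads per block)
residual_rms.ops.residual_rms(input, residual, weight, output, 1e-6, 1.0, 4, 1024)
```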
flake.lock
ADDED
@@ -0,0 +1,95 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1733328505,
+        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "kernel-builder": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "nixpkgs": "nixpkgs"
+      },
+      "locked": {
+        "lastModified": 1738315861,
+        "narHash": "sha256-QPWRaIPAMmQANuAOaZIKzh1e69OG8zBWGg+swESEajw=",
+        "ref": "refs/heads/main",
+        "rev": "eabeadcedba5dcef2a562b8f1ed5ec1feb485496",
+        "revCount": 72,
+        "type": "git",
+        "url": "ssh://[email protected]/huggingface/kernel-builder"
+      },
+      "original": {
+        "type": "git",
+        "url": "ssh://[email protected]/huggingface/kernel-builder"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1738247409,
+        "narHash": "sha256-F72dKl9Na6/2N+garOm9qCXPa92GzR8eYSuDra6kbjY=",
+        "owner": "danieldk",
+        "repo": "nixpkgs",
+        "rev": "358f57074b70e3ee9e1dc118151a4f6f81fcd3bb",
+        "type": "github"
+      },
+      "original": {
+        "owner": "danieldk",
+        "ref": "cuda-12.6-for-kernel-builder",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "kernel-builder": "kernel-builder"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
flake.nix
ADDED
@@ -0,0 +1,14 @@
+{
+  description = "Flake for rocm residual rms kernels";
+
+  inputs = {
+    kernel-builder.url = "git+ssh://[email protected]/huggingface/kernel-builder";
+  };
+
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
+    kernel-builder.lib.genFlakeOutputs ./.;
+}
residual_rms/compat.h
ADDED
@@ -0,0 +1,5 @@
+#pragma once
+
+#include <hip/hip_runtime.h>
+
+#define WARP_SIZE 32
residual_rms/residual_rms_dispatch.cu
ADDED
@@ -0,0 +1,56 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <hip/hip_runtime.h>
+
+#include "op_src/residual_rms/residual_rms_v0.cu"
+#include "op_src/residual_rms/residual_rms_v1.cu"
+#include "op_src/residual_rms/residual_rms_v2.cu"
+#include "op_src/residual_rms/residual_rms_v3.cu"
+#include "op_src/residual_rms/residual_rms_v4.cu"
+
+void residual_rms(torch::Tensor& input,    // Shape: [m, n] / Layout: row-major / Dtype: fp16
+                  torch::Tensor& residual, // Shape: [m, n] / Layout: row-major / Dtype: fp16
+                  torch::Tensor& weight,   // Shape: [n, ]  / Layout: row-major / Dtype: fp16
+                  torch::Tensor& output,   // Shape: [m, n] / Layout: row-major / Dtype: fp8
+                  double epsilon, double scale, int64_t mode,
+                  int64_t num_threads) { // TODO: add fp16 output mode
+
+  // Retrieve shapes
+  const int rows = input.size(0);
+  const int cols = input.size(1);
+  // Activate device guard
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+  // Prepare kernel launch arguments
+  dim3 grid(rows);
+  dim3 block(num_threads);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Launch kernel
+  switch (mode) {
+    case 1:
+      LAUNCH_RESIDUAL_RMS_V1;
+      break;
+    case 2:
+      LAUNCH_RESIDUAL_RMS_V2;
+      break;
+    case 3:
+      LAUNCH_RESIDUAL_RMS_V3;
+      break;
+    case 4:
+      LAUNCH_RESIDUAL_RMS_V4;
+      break;
+    default:
+      LAUNCH_RESIDUAL_RMS_V0;
+      break;
+  }
+}
+
+/*
+Versions:
+0. non-vectorized version
+1. vectorizes loads and stores
+2. simplified indexing
+3. added packed conversion
+4. using packed types everywhere and custom ASM for residual connection and variance
+*/
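Editorial note: for reference, the computation that all five kernel versions implement (residual add, per-row RMS normalization, weight, scale, fp8 clamp and cast) can be written as a short PyTorch sketch. This is an illustration, not part of the commit; the `residual_rms_reference` name and the use of `torch.float8_e4m3fnuz` are assumptions.

```python
import torch

def residual_rms_reference(
    input: torch.Tensor,     # [m, n], fp16
    residual: torch.Tensor,  # [m, n], fp16, updated in place as in the kernels
    weight: torch.Tensor,    # [n], fp16
    epsilon: float,
    scale: float,
) -> torch.Tensor:
    # Residual connection, kept in fp16 like the kernels do
    residual += input
    x = residual.float()
    # Per-row RMS normalizer: rsqrt(mean(x^2) + epsilon)
    normalizer = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    # Normalize, apply the fp16 weight, then scale and clamp to the finite
    # e4m3fnuz range (+-240), mirroring FP8_CLAMP + __hip_cvt_float_to_fp8
    y = ((x * normalizer).half() * weight).float() * scale
    return y.clamp(-240.0, 240.0).to(torch.float8_e4m3fnuz)
```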
residual_rms/residual_rms_v0.cu
ADDED
@@ -0,0 +1,57 @@
+#include <torch/all.h>
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#include <hip/hip_fp8.h>
+
+#include "utils/macros.h"
+
+__global__ void _residual_rms_v0(const half* __restrict__ input, half* __restrict__ residual,
+                                 const half* __restrict__ weight, __hip_fp8_storage_t* __restrict__ output,
+                                 const float epsilon, const float scale, const int cols) {
+  // Advance pointers according to the position of the thread in the grid
+  input += blockIdx.x * cols;
+  residual += blockIdx.x * cols;
+  output += blockIdx.x * cols;
+
+  // Residual connection: inplace add of input to residual, accumulate norm along the way
+  float variance = 0.0f;
+
+  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
+    half z = input[i];
+    z += residual[i];
+    float x = (float)z;
+    variance += (x * x);
+    residual[i] = z;
+  }
+  variance /= cols;
+
+  // Block reduce to compute the total norm
+  __shared__ float shared_normalizer;
+  using BlockReduce = hipcub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+  if (threadIdx.x == 0) {
+    shared_normalizer = rsqrtf(variance + epsilon);
+  }
+  __syncthreads();
+
+  // Normalize and convert
+  for (int idx = threadIdx.x; idx < cols; idx += blockDim.x) {
+    float x = (float)residual[idx];
+    half y = (half)(x * shared_normalizer);
+    y = (y * weight[idx]);
+    x = (float)y;
+    x *= scale;
+    FP8_CLAMP(x, float);
+    output[idx] = __hip_cvt_float_to_fp8(x, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+  }
+}
+
+#define LAUNCH_RESIDUAL_RMS_V0 \
+  (_residual_rms_v0<<<grid, block, 0, stream>>>((half*)input.data_ptr(), (half*)residual.data_ptr(), \
+                                                (half*)weight.data_ptr(), (__hip_fp8_storage_t*)output.data_ptr(), \
+                                                epsilon, scale, cols))
residual_rms/residual_rms_v1.cu
ADDED
@@ -0,0 +1,103 @@
+#include <torch/all.h>
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#include <hip/hip_fp8.h>
+
+#include "utils/macros.h"
+
+#define WPT 8 // WorkPerThreads
+
+__global__ void _residual_rms_v1(const half* __restrict__ input, half* __restrict__ residual,
+                                 const half* __restrict__ weight, __hip_fp8_storage_t* __restrict__ output,
+                                 const float epsilon, const float scale, const int cols) {
+  // Advance pointers according to the position of the thread in the grid
+  input += blockIdx.x * cols;
+  residual += blockIdx.x * cols;
+  output += blockIdx.x * cols;
+
+  // Residual connection: inplace add of input to residual, accumulate norm along the way
+  float variance = 0.0f;
+  float fp32_residual;
+  half input_buffer[WPT];
+  half residual_buffer[WPT];
+
+  for (int i = WPT * threadIdx.x; i < cols; i += WPT * blockDim.x) {
+    // Load data using 128-bits loads
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      input_buffer[j] = input[i + j];
+    }
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual_buffer[j] = residual[i + j];
+    }
+
+    // Add everything in the residual buffer and accumulate variance
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual_buffer[j] += input_buffer[j];
+      fp32_residual = (float)residual_buffer[j];
+      variance += fp32_residual * fp32_residual;
+    }
+
+    // 128-bits store
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual[i + j] = residual_buffer[j];
+    }
+  }
+  variance /= cols;
+
+  // Block reduce to compute the total norm
+  __shared__ float shared_normalizer;
+  using BlockReduce = hipcub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+  if (threadIdx.x == 0) {
+    shared_normalizer = rsqrtf(variance + epsilon);
+  }
+  __syncthreads();
+
+  // Normalize and convert
+  float tmp_float;
+  half residual_buffer_[WPT];
+  half weight_buffer[WPT];
+  __hip_fp8_storage_t fp8_buffer[WPT];
+
+  for (int i = WPT * threadIdx.x; i < cols; i += WPT * blockDim.x) {
+    // 128-bits loads
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual_buffer_[j] = residual[i + j];
+    }
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      weight_buffer[j] = weight[i + j];
+    }
+
+    // Compute and fill buffer
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      tmp_float = (float)residual_buffer_[j] * shared_normalizer;
+      tmp_float = (float)((half)(tmp_float)*weight_buffer[j]);
+      tmp_float *= scale;
+      FP8_CLAMP(tmp_float, float);
+      fp8_buffer[j] = __hip_cvt_float_to_fp8(tmp_float, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+    }
+
+    // 64b store
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      output[i + j] = fp8_buffer[j];
+    }
+  }
+}
+
+#define LAUNCH_RESIDUAL_RMS_V1 \
+  (_residual_rms_v1<<<grid, block, 0, stream>>>((half*)input.data_ptr(), (half*)residual.data_ptr(), \
+                                                (half*)weight.data_ptr(), (__hip_fp8_storage_t*)output.data_ptr(), \
+                                                epsilon, scale, cols))
residual_rms/residual_rms_v2.cu
ADDED
@@ -0,0 +1,118 @@
+#include <torch/all.h>
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#include <hip/hip_fp8.h>
+
+#include "utils/macros.h"
+
+#define WPT 8 // WorkPerThreads
+#define CDIV(a, b) ((a + b - 1) / (b)) // Ceiling division
+
+__global__ void _residual_rms_v2(const half* __restrict__ input, half* __restrict__ residual,
+                                 const half* __restrict__ weight, __hip_fp8_storage_t* __restrict__ output,
+                                 const float epsilon, const float scale, const int cols) {
+  // Advance pointers according to the position of the thread in the grid
+  input += blockIdx.x * cols + WPT * threadIdx.x;
+  residual += blockIdx.x * cols + WPT * threadIdx.x;
+  weight += WPT * threadIdx.x;
+  output += blockIdx.x * cols + WPT * threadIdx.x;
+  half* residual_start = residual;
+
+  // Residual connection: inplace add of input to residual, accumulate norm along the way
+  float variance = 0.0f;
+  float fp32_residual;
+  half input_buffer[WPT];
+  half residual_buffer[WPT];
+
+  const int loop_stride = WPT * blockDim.x;
+  const int iterations = CDIV(cols - WPT * threadIdx.x, loop_stride);
+  for (int i = 0; i < iterations; i++) {
+    // Load data using 128-bits loads
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      input_buffer[j] = input[j];
+    }
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual_buffer[j] = residual[j];
+    }
+
+    // Add everything in the residual buffer and accumulate variance
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual_buffer[j] += input_buffer[j];
+      fp32_residual = (float)residual_buffer[j];
+      variance += fp32_residual * fp32_residual;
+    }
+
+    // 128-bits store
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual[j] = residual_buffer[j];
+    }
+
+    // Advance pointers
+    input += loop_stride;
+    residual += loop_stride;
+  }
+  variance /= cols;
+
+  // Block reduce to compute the total norm
+  __shared__ float shared_normalizer;
+  using BlockReduce = hipcub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+  if (threadIdx.x == 0) {
+    shared_normalizer = rsqrtf(variance + epsilon);
+  }
+  __syncthreads();
+
+  // Normalize and convert
+  float tmp_float;
+  half residual_buffer_[WPT];
+  half weight_buffer[WPT];
+  __hip_fp8_storage_t fp8_buffer[WPT];
+
+  residual = residual_start;
+  for (int i = 0; i < iterations; i++) {
+    // 128-bits loads
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual_buffer_[j] = residual[j];
+    }
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      weight_buffer[j] = weight[j];
+    }
+
+    // Compute and fill buffer
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      tmp_float = (float)residual_buffer_[j] * shared_normalizer;
+      tmp_float = (float)((half)(tmp_float)*weight_buffer[j]);
+      tmp_float *= scale;
+      FP8_CLAMP(tmp_float, float);
+      fp8_buffer[j] = __hip_cvt_float_to_fp8(tmp_float, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+    }
+
+    // 64b store
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      output[j] = fp8_buffer[j];
+    }
+
+    // Advance pointers
+    residual += loop_stride;
+    weight += loop_stride;
+    output += loop_stride;
+  }
+}
+
+#define LAUNCH_RESIDUAL_RMS_V2 \
+  (_residual_rms_v2<<<grid, block, 0, stream>>>((half*)input.data_ptr(), (half*)residual.data_ptr(), \
+                                                (half*)weight.data_ptr(), (__hip_fp8_storage_t*)output.data_ptr(), \
+                                                epsilon, scale, cols))
residual_rms/residual_rms_v3.cu
ADDED
@@ -0,0 +1,125 @@
+#include <torch/all.h>
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#include <hip/hip_fp8.h>
+
+#include "utils/macros.h"
+
+#define WPT 8 // WorkPerThreads
+#define CDIV(a, b) ((a + b - 1) / (b)) // Ceiling division
+
+__global__ void _residual_rms_v3(const half* __restrict__ input, half* __restrict__ residual,
+                                 const half* __restrict__ weight, __hip_fp8x2_storage_t* __restrict__ output,
+                                 const float epsilon, const float scale, const int cols) {
+  // Advance pointers according to the position of the thread in the grid
+  input += blockIdx.x * cols + WPT * threadIdx.x;
+  residual += blockIdx.x * cols + WPT * threadIdx.x;
+  weight += WPT * threadIdx.x;
+  output += (blockIdx.x * cols + WPT * threadIdx.x) / 2;
+  half* residual_start = residual;
+
+  // Residual connection: inplace add of input to residual, accumulate norm along the way
+  float variance = 0.0f;
+  float fp32_residual;
+  half input_buffer[WPT];
+  half residual_buffer[WPT];
+
+  const int loop_stride = WPT * blockDim.x;
+  const int iterations = CDIV(cols - WPT * threadIdx.x, loop_stride);
+  for (int i = 0; i < iterations; i++) {
+    // Load data using 128-bits loads
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      input_buffer[j] = input[j];
+    }
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual_buffer[j] = residual[j];
+    }
+
+    // Add everything in the residual buffer and accumulate variance
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual_buffer[j] += input_buffer[j];
+      fp32_residual = (float)residual_buffer[j];
+      variance += fp32_residual * fp32_residual;
+    }
+
+    // 128-bits store
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual[j] = residual_buffer[j];
+    }
+
+    // Advance pointers
+    input += loop_stride;
+    residual += loop_stride;
+  }
+  variance /= cols;
+
+  // Block reduce to compute the total norm
+  __shared__ float shared_normalizer;
+  using BlockReduce = hipcub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+  if (threadIdx.x == 0) {
+    shared_normalizer = rsqrtf(variance + epsilon);
+  }
+  __syncthreads();
+
+  // Normalize and convert
+  float2 tmp_float2;
+  half residual_buffer_[WPT];
+  half weight_buffer[WPT];
+  __hip_fp8x2_storage_t fp8x2_buffer[WPT / 2];
+
+  residual = residual_start;
+  for (int i = 0; i < iterations; i++) {
+    // 128-bits loads
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      residual_buffer_[j] = residual[j];
+    }
+    #pragma unroll
+    for (int j = 0; j < WPT; j++) {
+      weight_buffer[j] = weight[j];
+    }
+
+    // Compute and fill buffer
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      // .x
+      tmp_float2.x = (float)residual_buffer_[2 * j] * shared_normalizer;
+      tmp_float2.x = (float)((half)(tmp_float2.x) * weight_buffer[2 * j]);
+      tmp_float2.x *= scale;
+      FP8_CLAMP(tmp_float2.x, float);
+      // .y
+      tmp_float2.y = (float)residual_buffer_[2 * j + 1] * shared_normalizer;
+      tmp_float2.y = (float)((half)(tmp_float2.y) * weight_buffer[2 * j + 1]);
+      tmp_float2.y *= scale;
+      FP8_CLAMP(tmp_float2.y, float);
+      // convert
+      fp8x2_buffer[j] = __hip_cvt_float2_to_fp8x2(tmp_float2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+    }
+
+    // 64b store
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      output[j] = fp8x2_buffer[j];
+    }
+
+    // Advance pointers
+    residual += loop_stride;
+    weight += loop_stride;
+    output += loop_stride / 2;
+  }
+}
+
+#define LAUNCH_RESIDUAL_RMS_V3 \
+  (_residual_rms_v3<<<grid, block, 0, stream>>>((half*)input.data_ptr(), (half*)residual.data_ptr(), \
+                                                (half*)weight.data_ptr(), (__hip_fp8x2_storage_t*)output.data_ptr(), \
+                                                epsilon, scale, cols))
residual_rms/residual_rms_v4.cu
ADDED
@@ -0,0 +1,126 @@
+#include <torch/all.h>
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#include <hip/hip_fp8.h>
+
+#include "utils/macros.h"
+
+#define WPT 8 // WorkPerThreads
+#define CDIV(a, b) ((a + b - 1) / (b)) // Ceiling division
+
+__global__ void _residual_rms_v4(const __half2* __restrict__ input, __half2* __restrict__ residual,
+                                 const __half2* __restrict__ weight, __hip_fp8x2_storage_t* __restrict__ output,
+                                 const float epsilon, const float scale, const int cols) {
+  // Advance pointers according to the position of the thread in the grid
+  input += (blockIdx.x * cols + WPT * threadIdx.x) / 2;
+  residual += (blockIdx.x * cols + WPT * threadIdx.x) / 2;
+  weight += (WPT * threadIdx.x) / 2;
+  output += (blockIdx.x * cols + WPT * threadIdx.x) / 2;
+
+  // Residual connection: inplace add of input to residual, accumulate norm along the way
+  float variance = 0.0f;
+  float fp32_residual;
+  __half2 input_buffer[WPT / 2];
+  __half2 residual_buffer[WPT / 2];
+
+  const int loop_stride = blockDim.x * (WPT / 2);
+  const int iterations = CDIV(cols - WPT * threadIdx.x, 2 * loop_stride);
+  for (int i = 0; i < iterations; i++) {
+    // Load data using 128-bits loads
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      input_buffer[j] = input[j];
+    }
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      residual_buffer[j] = residual[j];
+    }
+
+    // Residual connection and variance accumulation
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      asm volatile(
+          "V_PK_ADD_F16 %0, %2, %3\n\t"
+          "V_DOT2C_F32_F16 %1, %2, %2"
+          : "=v"(residual_buffer[j]), "=v"(variance)
+          : "0"(residual_buffer[j]), "v"(input_buffer[j]));
+    }
+
+    // 128-bits store
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      residual[j] = residual_buffer[j];
+    }
+
+    // Advance pointers
+    input += loop_stride;
+    residual += loop_stride;
+  }
+  variance /= cols;
+
+  // Block reduce to compute the total norm
+  __shared__ float shared_normalizer;
+  using BlockReduce = hipcub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+  if (threadIdx.x == 0) {
+    shared_normalizer = rsqrtf(variance + epsilon);
+  }
+  __syncthreads();
+
+  // Normalize and convert
+  float2 tmp_float2;
+  __half2 residual_buffer_[WPT / 2];
+  __half2 weight_buffer[WPT / 2];
+  __hip_fp8x2_storage_t fp8x2_buffer[WPT / 2];
+
+  residual -= iterations * loop_stride;
+  for (int i = 0; i < iterations; i++) {
+    // 128-bits loads
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      residual_buffer_[j] = residual[j];
+    }
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      weight_buffer[j] = weight[j];
+    }
+
+    // Compute and fill buffer
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      // .x
+      tmp_float2.x = (float)residual_buffer_[j].x * shared_normalizer;
+      tmp_float2.x = (float)((half)(tmp_float2.x) * weight_buffer[j].x);
+      tmp_float2.x *= scale;
+      FP8_CLAMP(tmp_float2.x, float);
+      // .y
+      tmp_float2.y = (float)residual_buffer_[j].y * shared_normalizer;
+      tmp_float2.y = (float)((half)(tmp_float2.y) * weight_buffer[j].y);
+      tmp_float2.y *= scale;
+      FP8_CLAMP(tmp_float2.y, float);
+      // convert
+      fp8x2_buffer[j] = __hip_cvt_float2_to_fp8x2(tmp_float2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+    }
+
+    // 64b store
+    #pragma unroll
+    for (int j = 0; j < WPT / 2; j++) {
+      output[j] = fp8x2_buffer[j];
+    }
+
+    // Advance pointers
+    residual += loop_stride;
+    weight += loop_stride;
+    output += loop_stride;
+  }
+}
+
+#define LAUNCH_RESIDUAL_RMS_V4 \
+  (_residual_rms_v4<<<grid, block, 0, stream>>>((__half2*)input.data_ptr(), (__half2*)residual.data_ptr(), \
+                                                (__half2*)weight.data_ptr(), \
+                                                (__hip_fp8x2_storage_t*)output.data_ptr(), epsilon, scale, cols))
test/__init__.py
ADDED
File without changes
test/kernels/__init__.py
ADDED
File without changes
test/kernels/test_residual_rms.py
ADDED
@@ -0,0 +1,20 @@
+"""Tests for the `residual_rms` kernel.
+
+Run `pytest tests/kernels/test_residual_rms.py`.
+"""
+
+from typing import List
+
+import pytest
+import torch
+
+from residual_rms._ops import ops
+from residual_rms.residual_rms import residual_rms
+
+
+@pytest.mark.parametrize("shape", [(2, 3, 4, 5), (2, 3, 4, 5, 6)])
+def test_residual_rms(shape: List[int]) -> None:
+    x = torch.randn(shape)
+    out = torch.zeros_like(x)
+    residual_rms(out, x)
+    assert torch.allclose(out, ops.residual_rms(x))
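Editorial note: the test above still calls the template's two-argument wrapper, which does not match the eight-argument schema registered in `torch_binding.cpp`. A hedged sketch of a test that exercises the actual op against an fp32 PyTorch reference follows; the device, the `torch.float8_e4m3fnuz` dtype, and the tolerance values are assumptions, not part of this commit.

```python
import pytest
import torch

from residual_rms import ops  # assumes the compiled extension is importable


@pytest.mark.parametrize("mode", [0, 1, 2, 3, 4])
def test_residual_rms_modes(mode: int) -> None:
    m, n, eps, scale = 8, 4096, 1e-6, 1.0
    input = torch.randn(m, n, dtype=torch.float16, device="cuda")
    residual = torch.randn(m, n, dtype=torch.float16, device="cuda")
    weight = torch.randn(n, dtype=torch.float16, device="cuda")
    output = torch.empty(m, n, dtype=torch.float8_e4m3fnuz, device="cuda")

    # fp32 reference of the same computation: residual add, RMS norm, weight, scale
    ref_res = residual.float() + input.float()
    normalizer = torch.rsqrt(ref_res.pow(2).mean(-1, keepdim=True) + eps)
    ref = ((ref_res * normalizer).half() * weight).float() * scale
    ref = ref.clamp(-240.0, 240.0)

    ops.residual_rms(input, residual, weight, output, eps, scale, mode, 1024)

    # Loose tolerances: the kernel output is quantized to fp8 e4m3fnuz
    torch.testing.assert_close(output.float(), ref, atol=0.25, rtol=0.1)
```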