|
import torch |
|
|
|
from bitsandbytes.triton.triton_utils import is_triton_available |
|
|
|
if not is_triton_available(): |
|
|
|
def quantize_global_transpose(input): |
|
return None |
|
|
|
def quantize_global(x: torch.Tensor): |
|
return None |
|
else: |
|
import triton |
|
import triton.language as tl |
|
|
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), |
|
triton.Config({"BLOCK_SIZE": 2048}, num_stages=1), |
|
], |
|
key=["n_elements"], |
|
) |
|
@triton.jit |
|
def _quantize_global( |
|
x_ptr, |
|
absmax_inv_ptr, |
|
output_ptr, |
|
n_elements, |
|
BLOCK_SIZE: tl.constexpr, |
|
): |
|
pid = tl.program_id(axis=0) |
|
block_start = pid * BLOCK_SIZE |
|
offsets = block_start + tl.arange(0, BLOCK_SIZE) |
|
mask = offsets < n_elements |
|
x = tl.load(x_ptr + offsets, mask=mask) |
|
absmax_inv = tl.load(absmax_inv_ptr) |
|
output = tl.libdevice.llrint(127.0 * (x * absmax_inv)) |
|
tl.store(output_ptr + offsets, output, mask=mask) |
|
|
|
def quantize_global(x: torch.Tensor): |
|
absmax = x.abs().max().unsqueeze(0) |
|
absmax_inv = 1.0 / absmax |
|
output = torch.empty(*x.shape, device="cuda", dtype=torch.int8) |
|
assert x.is_cuda and output.is_cuda |
|
n_elements = output.numel() |
|
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
|
_quantize_global[grid](x, absmax_inv, output, n_elements) |
|
return output, absmax |
|
|
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), |
|
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), |
|
|
|
], |
|
key=["M", "N"], |
|
) |
|
@triton.jit |
|
def _quantize_global_transpose( |
|
A, |
|
absmax_inv_ptr, |
|
B, |
|
stride_am, |
|
stride_an, |
|
stride_bn, |
|
stride_bm, |
|
M, |
|
N, |
|
BLOCK_M: tl.constexpr, |
|
BLOCK_N: tl.constexpr, |
|
GROUP_M: tl.constexpr, |
|
): |
|
pid = tl.program_id(0) |
|
grid_m = (M + BLOCK_M - 1) // BLOCK_M |
|
grid_n = (N + BLOCK_N - 1) // BLOCK_N |
|
|
|
width = GROUP_M * grid_n |
|
group_id = pid // width |
|
group_size = min(grid_m - group_id * GROUP_M, GROUP_M) |
|
pid_m = group_id * GROUP_M + (pid % group_size) |
|
pid_n = (pid % width) // group_size |
|
|
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
|
A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) |
|
mask = (rm < M)[:, None] & (rn < N)[None, :] |
|
a = tl.load(A, mask=mask) |
|
absmax_inv = tl.load(absmax_inv_ptr) |
|
|
|
|
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
|
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) |
|
mask = (rm < M)[:, None] & (rn < N)[None, :] |
|
|
|
output = tl.libdevice.llrint(127.0 * (a * absmax_inv)) |
|
|
|
tl.store(B, output, mask=mask) |
|
|
|
def quantize_global_transpose(input): |
|
absmax = input.abs().max().unsqueeze(0) |
|
absmax_inv = 1.0 / absmax |
|
M, N = input.shape |
|
out = torch.empty(N, M, device="cuda", dtype=torch.int8) |
|
|
|
assert out.size(0) == N and out.size(1) == M |
|
assert input.stride(0) == 1 or input.stride(1) == 1 |
|
assert out.stride(0) == 1 or out.stride(1) == 1 |
|
|
|
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) |
|
_quantize_global_transpose[grid]( |
|
input, |
|
absmax_inv, |
|
out, |
|
input.stride(0), |
|
input.stride(1), |
|
out.stride(0), |
|
out.stride(1), |
|
M, |
|
N, |
|
) |
|
return out, absmax |
|
|