import torch from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): def quantize_global_transpose(input): return None def quantize_global(x: torch.Tensor): return None else: import triton import triton.language as tl # global quantize @triton.autotune( configs=[ triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), triton.Config({"BLOCK_SIZE": 2048}, num_stages=1), ], key=["n_elements"], ) @triton.jit def _quantize_global( x_ptr, absmax_inv_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) absmax_inv = tl.load(absmax_inv_ptr) output = tl.libdevice.llrint(127.0 * (x * absmax_inv)) tl.store(output_ptr + offsets, output, mask=mask) def quantize_global(x: torch.Tensor): absmax = x.abs().max().unsqueeze(0) absmax_inv = 1.0 / absmax output = torch.empty(*x.shape, device="cuda", dtype=torch.int8) assert x.is_cuda and output.is_cuda n_elements = output.numel() grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) _quantize_global[grid](x, absmax_inv, output, n_elements) return output, absmax # global quantize and transpose @triton.autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), # ... ], key=["M", "N"], ) @triton.jit def _quantize_global_transpose( A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_M: tl.constexpr, ): pid = tl.program_id(0) grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N width = GROUP_M * grid_n group_id = pid // width group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // group_size rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) mask = (rm < M)[:, None] & (rn < N)[None, :] a = tl.load(A, mask=mask) absmax_inv = tl.load(absmax_inv_ptr) # rematerialize to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) mask = (rm < M)[:, None] & (rn < N)[None, :] output = tl.libdevice.llrint(127.0 * (a * absmax_inv)) tl.store(B, output, mask=mask) def quantize_global_transpose(input): absmax = input.abs().max().unsqueeze(0) absmax_inv = 1.0 / absmax M, N = input.shape out = torch.empty(N, M, device="cuda", dtype=torch.int8) assert out.size(0) == N and out.size(1) == M assert input.stride(0) == 1 or input.stride(1) == 1 assert out.stride(0) == 1 or out.stride(1) == 1 grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) _quantize_global_transpose[grid]( input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N, ) return out, absmax