"""Triton kernels for global (tensor-wise) absmax int8 quantization.

When Triton is not available, no-op stubs are exported instead so the
names can still be imported.
"""

import torch

from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    # Triton unavailable: export no-op stubs so callers can import these
    # names and choose a non-Triton fallback path.

    def quantize_global_transpose(input):
        return None

    def quantize_global(x: torch.Tensor):
        return None
else:
    import triton
    import triton.language as tl

    # Global quantize: scale the whole tensor by a single factor so the
    # largest-magnitude entry maps to +/-127, i.e. q = llrint(127 * x / absmax).
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2048}, num_stages=1),
        ],
        key=["n_elements"],
    )
    @triton.jit
    def _quantize_global(
        x_ptr,
        absmax_inv_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program instance quantizes one contiguous block of elements.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements  # guard the ragged final block
        x = tl.load(x_ptr + offsets, mask=mask)
        absmax_inv = tl.load(absmax_inv_ptr)  # scalar 1 / absmax
        # Scale into [-127, 127] and round to the nearest integer.
        output = tl.libdevice.llrint(127.0 * (x * absmax_inv))
        tl.store(output_ptr + offsets, output, mask=mask)

    def quantize_global(x: torch.Tensor):
        absmax = x.abs().max().unsqueeze(0)
        absmax_inv = 1.0 / absmax
        # Allocate on the input's device so multi-GPU setups don't silently
        # write across devices.
        output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
        # One program per BLOCK_SIZE-element chunk; BLOCK_SIZE is autotuned.
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _quantize_global[grid](x, absmax_inv, output, n_elements)
        return output, absmax
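
    # A minimal eager-mode sketch of what quantize_global computes, handy as a
    # reference when validating the Triton path. The helper name is ours, not
    # part of bitsandbytes; torch.round matches llrint's default
    # round-to-nearest-even behavior. Example: with absmax = 2.0, x = 1.5
    # gives round(127 * 0.75) = 95.
    def _quantize_global_reference(x: torch.Tensor):
        absmax = x.abs().max().unsqueeze(0)
        out = torch.round(127.0 * (x / absmax)).to(torch.int8)
        return out, absmax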

    # Global quantize and transpose: fuse the int8 quantization with writing
    # the result in transposed (N, M) layout, avoiding a separate transpose pass.
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
            # ...
        ],
        key=["M", "N"],
    )
    @triton.jit
    def _quantize_global_transpose(
        A,
        absmax_inv_ptr,
        B,
        stride_am,
        stride_an,
        stride_bn,
        stride_bm,
        M,
        N,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        GROUP_M: tl.constexpr,
    ):
        pid = tl.program_id(0)
        grid_m = (M + BLOCK_M - 1) // BLOCK_M
        grid_n = (N + BLOCK_N - 1) // BLOCK_N

        # Remap the linear program id with the standard Triton "grouped
        # ordering" swizzle so tiles that are close in M execute close
        # together, improving L2 locality.
        width = GROUP_M * grid_n
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // group_size

        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        a = tl.load(A, mask=mask)
        absmax_inv = tl.load(absmax_inv_ptr)

        # Rematerialize the index vectors instead of keeping them alive across
        # the load: recomputing them is cheaper than the extra register pressure.
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]

        output = tl.libdevice.llrint(127.0 * (a * absmax_inv))

        tl.store(B, output, mask=mask)
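
    # Eager-mode sketch of the fused quantize + transpose above (our reference
    # helper, not part of bitsandbytes): quantize with the single global scale,
    # then materialize the transposed int8 tensor.
    def _quantize_global_transpose_reference(x: torch.Tensor):
        absmax = x.abs().max().unsqueeze(0)
        out = torch.round(127.0 * (x / absmax)).to(torch.int8)
        return out.t().contiguous(), absmax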

    def quantize_global_transpose(input: torch.Tensor):
        absmax = input.abs().max().unsqueeze(0)
        absmax_inv = 1.0 / absmax
        M, N = input.shape
        # Output holds the transpose: element (n, m) of `out` is the quantized
        # element (m, n) of `input`. Allocate on the input's device.
        out = torch.empty(N, M, device=input.device, dtype=torch.int8)

        assert out.size(0) == N and out.size(1) == M
        # Both tensors need a unit-stride dimension for coalesced access.
        assert input.stride(0) == 1 or input.stride(1) == 1
        assert out.stride(0) == 1 or out.stride(1) == 1

        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
        _quantize_global_transpose[grid](
            input,
            absmax_inv,
            out,
            input.stride(0),
            input.stride(1),
            out.stride(0),
            out.stride(1),
            M,
            N,
        )
        return out, absmax
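

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the original module): compare the
    # Triton kernels with the eager references defined above. Assumes a CUDA
    # device with Triton installed; tolerate off-by-one results at rounding
    # boundaries, since the kernel multiplies by 1/absmax rather than dividing.
    if is_triton_available() and torch.cuda.is_available():
        x = torch.randn(512, 1024, device="cuda")
        q, absmax = quantize_global(x)
        q_ref, absmax_ref = _quantize_global_reference(x)
        assert torch.equal(absmax, absmax_ref)
        assert (q.float() - q_ref.float()).abs().max() <= 1
        qt, _ = quantize_global_transpose(x)
        assert qt.shape == (1024, 512)
        assert (qt.float() - q.t().float()).abs().max() <= 1
        print("global quantize kernels match the eager reference")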