import torch
from torch import nn

import bitsandbytes as bnb


class LinearFP8Mixed(nn.Linear):
    """nn.Linear drop-in whose matmul runs in FP8 via bnb.research.matmul_fp8_mixed.

    The layer keeps full-precision master weights and quantizes on the fly:
    an E4M3 code is used on the forward pass and an E5M2 code on the
    backward pass.
    """

    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.bw_code = None
        self.fw_code = None
        # Pick the smallest power-of-two blocksize that covers each
        # dimension, clamped to the range [64, 4096].
        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
        for i, k in enumerate(array):
            if input_features > array[i + 1]:
                self.bsz = k
                break
        for i, k in enumerate(array):
            if output_features > array[i + 1]:
                self.bsz2 = k
                break

    def forward(self, x: torch.Tensor):
        # Build the quantization maps lazily on first call so they land on
        # the same device as the input: E5M2 (5 exponent / 2 mantissa bits)
        # for gradients, E4M3 (4 exponent / 3 mantissa bits) for activations.
        if self.fw_code is None:
            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

        out = bnb.research.matmul_fp8_mixed(
            x,
            self.weight.t(),
            fw_code=self.fw_code,
            bw_code=self.bw_code,
            bsz=self.bsz,
            bsz2=self.bsz2,
        )
        if self.bias is not None:
            out += self.bias

        return out


class LinearFP8Global(nn.Linear):
    """Variant of LinearFP8Mixed that routes the matmul through
    bnb.research.matmul_fp8_global."""

    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.bw_code = None
        self.fw_code = None
        # Same blocksize heuristic as LinearFP8Mixed.
        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
        for i, k in enumerate(array):
            if input_features > array[i + 1]:
                self.bsz = k
                break
        for i, k in enumerate(array):
            if output_features > array[i + 1]:
                self.bsz2 = k
                break

    def forward(self, x: torch.Tensor):
        # Same lazy map construction as LinearFP8Mixed.
        if self.fw_code is None:
            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

        # matmul_fp8_global is exported from bnb.research, matching the call
        # in LinearFP8Mixed above.
        out = bnb.research.matmul_fp8_global(
            x,
            self.weight.t(),
            fw_code=self.fw_code,
            bw_code=self.bw_code,
            bsz=self.bsz,
            bsz2=self.bsz2,
        )
        if self.bias is not None:
            out += self.bias

        return out
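

# Minimal usage sketch (illustrative, not part of the module): drop an FP8
# layer in place of nn.Linear and run a forward/backward pass. Assumes a
# CUDA build of bitsandbytes; the shapes below are arbitrary.
if __name__ == "__main__":
    torch.manual_seed(0)
    layer = LinearFP8Mixed(256, 128).cuda()
    x = torch.randn(4, 256, device="cuda", requires_grad=True)
    out = layer(x)        # matmul quantized with the E4M3 forward code
    out.sum().backward()  # gradient matmul uses the E5M2 backward code
    print(out.shape, x.grad.shape)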