import torch
from transformers import PreTrainedModel

from ..utils import torch_gc


class CPUTextEncoderWrapper(PreTrainedModel):
    """Runs a text encoder on CPU in float32 while reporting the original
    device and dtype to callers, so the surrounding pipeline needs no changes."""

    def __init__(self, text_encoder, torch_dtype):
        super().__init__(text_encoder.config)
        self.config = text_encoder.config
        self._device = text_encoder.device
        # CPU inference does not support float16, so move the encoder to CPU
        # and cast it to float32.
        self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
        self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
        self.torch_dtype = torch_dtype
        # Drop the local reference and let torch_gc() (from ..utils) release
        # any cached accelerator memory the encoder was holding.
        del text_encoder
        torch_gc()

    def __call__(self, x, **kwargs):
        input_device = x.device
        # Encode on CPU, then move every output tensor back to the caller's
        # device and requested dtype. Tuple-valued outputs are converted
        # element-wise.
        original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
        for k, v in original_output.items():
            if isinstance(v, tuple):
                original_output[k] = [
                    v[i].to(input_device).to(self.torch_dtype) for i in range(len(v))
                ]
            else:
                original_output[k] = v.to(input_device).to(self.torch_dtype)
        return original_output

    @property
    def dtype(self):
        return self.torch_dtype

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device the wrapped encoder originally lived on. The
        parameters themselves stay on CPU; reporting the original device keeps
        callers unaware of the swap.
        """
        return self._device
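
# A minimal usage sketch (not part of the original module), assuming a diffusers
# Stable Diffusion pipeline running in float16 on CUDA; the model id and the
# `pipe` variable are illustrative assumptions. Prompt encoding then happens on
# CPU in float32, and the embeddings come back on CUDA in float16.
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#   ).to("cuda")
#   pipe.text_encoder = CPUTextEncoderWrapper(pipe.text_encoder, torch.float16)
#   image = pipe("a photo of an astronaut riding a horse").images[0]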