davidvgilmore committed on
Commit 2181970 · verified
1 Parent(s): ac146c5

Upload hy3dgen/shapegen/models/conditioner.py with huggingface_hub
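For reference, an upload like this is typically produced with huggingface_hub's upload_file API; the repo id below is a placeholder, not taken from this commit:

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="hy3dgen/shapegen/models/conditioner.py",  # local file to push
    path_in_repo="hy3dgen/shapegen/models/conditioner.py",     # destination path in the repo
    repo_id="<user>/<repo>",                                   # placeholder repo id
    commit_message="Upload hy3dgen/shapegen/models/conditioner.py with huggingface_hub",
)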

hy3dgen/shapegen/models/conditioner.py ADDED
@@ -0,0 +1,165 @@
+ # Open Source Model Licensed under the Apache License Version 2.0
+ # and Other Licenses of the Third-Party Components therein:
+ # The below Model in this distribution may have been modified by THL A29 Limited
+ # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+ # The below software and/or models in this distribution may have been
+ # modified by THL A29 Limited ("Tencent Modifications").
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+ # except for the third-party components listed below.
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
+ # in the respective licenses of these third-party components.
+ # Users must comply with all terms and conditions of original licenses of these third-party
+ # components and must ensure that the usage of the third party components adheres to
+ # all relevant laws and regulations.
+
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
+ # their software and algorithms, including trained model weights, parameters (including
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms
+ from transformers import (
+     CLIPVisionModelWithProjection,
+     CLIPVisionConfig,
+     Dinov2Model,
+     Dinov2Config,
+ )
+
+
+ class ImageEncoder(nn.Module):
+     def __init__(
+         self,
+         version=None,
+         config=None,
+         use_cls_token=True,
+         image_size=224,
+         **kwargs,
+     ):
+         super().__init__()
+
+         if config is None:
+             self.model = self.MODEL_CLASS.from_pretrained(version)
+         else:
+             self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))
+         self.model.eval()
+         self.model.requires_grad_(False)
+         self.use_cls_token = use_cls_token
+         self.size = image_size // 14
+         self.num_patches = (image_size // 14) ** 2
+         if self.use_cls_token:
+             self.num_patches += 1
+
+         self.transform = transforms.Compose(
+             [
+                 transforms.Resize(image_size, transforms.InterpolationMode.BILINEAR, antialias=True),
+                 transforms.CenterCrop(image_size),
+                 transforms.Normalize(
+                     mean=self.mean,
+                     std=self.std,
+                 ),
+             ]
+         )
+
+     def forward(self, image, mask=None, value_range=(-1, 1)):
+         if value_range is not None:
+             low, high = value_range
+             image = (image - low) / (high - low)
+
+         image = image.to(self.model.device, dtype=self.model.dtype)
+         inputs = self.transform(image)
+         outputs = self.model(inputs)
+
+         last_hidden_state = outputs.last_hidden_state
+         if not self.use_cls_token:
+             last_hidden_state = last_hidden_state[:, 1:, :]
+
+         return last_hidden_state
+
+     def unconditional_embedding(self, batch_size):
+         device = next(self.model.parameters()).device
+         dtype = next(self.model.parameters()).dtype
+         zero = torch.zeros(
+             batch_size,
+             self.num_patches,
+             self.model.config.hidden_size,
+             device=device,
+             dtype=dtype,
+         )
+
+         return zero
+
+
+ class CLIPImageEncoder(ImageEncoder):
+     MODEL_CLASS = CLIPVisionModelWithProjection
+     MODEL_CONFIG_CLASS = CLIPVisionConfig
+     mean = [0.48145466, 0.4578275, 0.40821073]
+     std = [0.26862954, 0.26130258, 0.27577711]
+
+
+ class DinoImageEncoder(ImageEncoder):
+     MODEL_CLASS = Dinov2Model
+     MODEL_CONFIG_CLASS = Dinov2Config
+     mean = [0.485, 0.456, 0.406]
+     std = [0.229, 0.224, 0.225]
+
+
+ def build_image_encoder(config):
+     if config['type'] == 'CLIPImageEncoder':
+         return CLIPImageEncoder(**config['kwargs'])
+     elif config['type'] == 'DinoImageEncoder':
+         return DinoImageEncoder(**config['kwargs'])
+     else:
+         raise ValueError(f'Unknown image encoder type: {config["type"]}')
+
+
+ class DualImageEncoder(nn.Module):
+     def __init__(
+         self,
+         main_image_encoder,
+         additional_image_encoder,
+     ):
+         super().__init__()
+         self.main_image_encoder = build_image_encoder(main_image_encoder)
+         self.additional_image_encoder = build_image_encoder(additional_image_encoder)
+
+     def forward(self, image, mask=None):
+         outputs = {
+             'main': self.main_image_encoder(image, mask=mask),
+             'additional': self.additional_image_encoder(image, mask=mask),
+         }
+         return outputs
+
+     def unconditional_embedding(self, batch_size):
+         outputs = {
+             'main': self.main_image_encoder.unconditional_embedding(batch_size),
+             'additional': self.additional_image_encoder.unconditional_embedding(batch_size),
+         }
+         return outputs
+
+
+ class SingleImageEncoder(nn.Module):
+     def __init__(
+         self,
+         main_image_encoder,
+     ):
+         super().__init__()
+         self.main_image_encoder = build_image_encoder(main_image_encoder)
+
+     def forward(self, image, mask=None):
+         outputs = {
+             'main': self.main_image_encoder(image, mask=mask),
+         }
+         return outputs
+
+     def unconditional_embedding(self, batch_size):
+         outputs = {
+             'main': self.main_image_encoder.unconditional_embedding(batch_size),
+         }
+         return outputs
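
A minimal usage sketch of the conditioner added above (not part of the commit); the DINOv2 checkpoint id and the input tensor shape are illustrative assumptions:

import torch
from hy3dgen.shapegen.models.conditioner import SingleImageEncoder

# build_image_encoder expects a dict with the encoder class name under 'type'
# and constructor arguments under 'kwargs'.
main_cfg = {
    'type': 'DinoImageEncoder',
    'kwargs': {
        'version': 'facebook/dinov2-base',  # assumed checkpoint id, not from the commit
        'use_cls_token': True,
        'image_size': 224,
    },
}

conditioner = SingleImageEncoder(main_image_encoder=main_cfg)

# forward() expects (B, 3, H, W) images in value_range (default (-1, 1)); it rescales
# them to [0, 1] before the resize/crop/normalize transform is applied.
images = torch.rand(2, 3, 518, 518) * 2 - 1
cond = conditioner(images)                                  # {'main': (B, num_patches, hidden_size)}
uncond = conditioner.unconditional_embedding(batch_size=2)  # zeros with matching shape
print(cond['main'].shape, uncond['main'].shape)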