Upload hy3dgen/shapegen/models/vae.py with huggingface_hub
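A commit like this is typically produced by pushing the file programmatically; a minimal sketch with the huggingface_hub client (the repo id and token below are placeholders, not values taken from this Space):

    from huggingface_hub import HfApi

    api = HfApi(token="hf_xxx")  # placeholder token
    api.upload_file(
        path_or_fileobj="hy3dgen/shapegen/models/vae.py",  # local file to push
        path_in_repo="hy3dgen/shapegen/models/vae.py",     # destination path in the repo
        repo_id="<user>/<space>",                          # placeholder Space id
        repo_type="space",
        commit_message="Upload hy3dgen/shapegen/models/vae.py with huggingface_hub",
    )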
hy3dgen/shapegen/models/vae.py  ADDED  (+636, -0)
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from typing import Tuple, List, Union, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from skimage import measure
from tqdm import tqdm


class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(f_1 * x[..., i]),
            sin(f_2 * x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(f_1 * x[..., i]),
            cos(f_2 * x[..., i]),
            ...
            cos(f_N * x[..., i]),
            x[..., i]     # only present if include_input is True.
        ], where f_i is the i-th frequency and N = num_freqs.

    If logspace is True, the frequencies are the powers of two [2^0, 2^1, ..., 2^(num_freqs - 1)];
    otherwise they are linearly spaced between 1.0 and 2^(num_freqs - 1).
    If include_pi is True, every frequency is additionally multiplied by pi.

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): if True, the frequencies are [2^0, ..., 2^(num_freqs - 1)],
            otherwise they are linearly spaced between 1.0 and 2^(num_freqs - 1);
        input_dim (int): the input dimension, default is 3;
        include_input (bool): include the input tensor or not, default is True;
        include_pi (bool): multiply the frequencies by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): the frequency buffer described above;
        out_dim (int): the embedding size; if include_input is True, it is input_dim * (num_freqs * 2 + 1),
            otherwise it is input_dim * num_freqs * 2.
    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        """The initialization"""
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """
        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x

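# Example (editor's illustrative sketch, not part of the original file): with
# include_input=True, each input coordinate contributes 2 * num_freqs sin/cos
# features plus the raw value, so a 3-D point embeds into 3 * (2 * 8 + 1) = 51
# channels for num_freqs=8, the setting ShapeVAE uses below:
#
#   >>> embedder = FourierEmbedder(num_freqs=8, input_dim=3, include_input=True)
#   >>> embedder.out_dim
#   51
#   >>> embedder(torch.rand(2, 1024, 3)).shape
#   torch.Size([2, 1024, 51])
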
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.
        """
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'

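# Example (editor's illustrative sketch, not part of the original file): during
# training, DropPath zeroes the whole residual branch for a random subset of
# samples and rescales the survivors by 1 / keep_prob so the expectation is
# preserved; in eval mode it is the identity:
#
#   >>> dp = DropPath(drop_prob=0.1)
#   >>> _ = dp.train()
#   >>> y = dp(torch.ones(4, 16))   # each row is either all 0 or all 1 / 0.9
#   >>> _ = dp.eval()
#   >>> torch.equal(dp(torch.ones(4, 16)), torch.ones(4, 16))
#   True
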
class MLP(nn.Module):
    def __init__(
        self, *,
        width: int,
        output_width: int = None,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * 4)
        self.c_proj = nn.Linear(width * 4, output_width if output_width is not None else width)
        self.gelu = nn.GELU()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        n_data: Optional[int] = None,
        width=None,
        qk_norm=False,
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.n_data = n_data
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        # kv packs keys and values along the channel dimension, so each gets width // heads // 2 per head
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)

        return out


class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        n_data: Optional[int] = None,
        data_width: Optional[int] = None,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False
    ):
        super().__init__()
        self.n_data = n_data
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = nn.Linear(width, width, bias=qkv_bias)
        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            n_data=n_data,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )

    def forward(self, x, data):
        x = self.c_q(x)
        data = self.c_kv(data)
        x = self.attention(x, data)
        x = self.c_proj(x)
        return x


class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        n_data: Optional[int] = None,
        width: int,
        heads: int,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            n_data=n_data,
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x

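# Example (editor's illustrative sketch, not part of the original file): in this
# block the query tokens x attend to the data tokens (later, the VAE latents), so
# the output keeps the query sequence length and width regardless of how many
# data tokens there are:
#
#   >>> block = ResidualCrossAttentionBlock(width=64, heads=4)
#   >>> x = torch.rand(2, 100, 64)       # 100 query tokens
#   >>> data = torch.rand(2, 256, 64)    # 256 data tokens
#   >>> block(x, data).shape
#   torch.Size([2, 100, 64])
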
class QKVMultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        n_ctx: int,
        width=None,
        qk_norm=False,
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.n_ctx = n_ctx
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        # qkv packs queries, keys and values along the channel dimension, so each gets width // heads // 3 per head
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            n_ctx=n_ctx,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.drop_path(self.c_proj(x))
        return x


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    drop_path_rate=drop_path_rate
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x


class CrossAttentionDecoder(nn.Module):

    def __init__(
        self,
        *,
        num_latents: int,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        super().__init__()

        self.fourier_embedder = fourier_embedder

        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)

        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            n_data=num_latents,
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        self.ln_post = nn.LayerNorm(width)
        self.output_proj = nn.Linear(width, out_channels)
        self.label_type = label_type

    def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        queries = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
        x = self.cross_attn_decoder(queries, latents)
        x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ

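# Example (editor's illustrative sketch, not part of the original file): the decoder
# Fourier-embeds each query point, lets the points cross-attend to the latent set,
# and emits one occupancy logit per point:
#
#   >>> emb = FourierEmbedder(num_freqs=8)
#   >>> dec = CrossAttentionDecoder(num_latents=256, out_channels=1,
#   ...                             fourier_embedder=emb, width=64, heads=4)
#   >>> pts = torch.rand(2, 5000, 3)        # query coordinates
#   >>> latents = torch.rand(2, 256, 64)    # latent tokens after the transformer
#   >>> dec(pts, latents).shape
#   torch.Size([2, 5000, 1])
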
def generate_dense_grid_points(bbox_min: np.ndarray,
                               bbox_max: np.ndarray,
                               octree_depth: int,
                               indexing: str = "ij",
                               octree_resolution: int = None,
                               ):
    length = bbox_max - bbox_min
    num_cells = np.exp2(octree_depth)
    if octree_resolution is not None:
        num_cells = octree_resolution

    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    xyz = xyz.reshape(-1, 3)
    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]

    return xyz, grid_size, length

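# Example (editor's illustrative sketch, not part of the original file): the grid has
# num_cells + 1 samples per axis, so octree_depth=7 (2^7 = 128 cells) yields a
# 129 x 129 x 129 lattice of 2,146,689 query points:
#
#   >>> xyz, grid_size, length = generate_dense_grid_points(
#   ...     np.array([-1.1, -1.1, -1.1]), np.array([1.1, 1.1, 1.1]), octree_depth=7)
#   >>> grid_size, xyz.shape
#   ([129, 129, 129], (2146689, 3))
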
def center_vertices(vertices):
    """Translate the vertices so that bounding box is centered at zero."""
    vert_min = vertices.min(dim=0)[0]
    vert_max = vertices.max(dim=0)[0]
    vert_center = 0.5 * (vert_min + vert_max)
    return vertices - vert_center


class Latent2MeshOutput:

    def __init__(self, mesh_v=None, mesh_f=None):
        self.mesh_v = mesh_v
        self.mesh_f = mesh_f


class ShapeVAE(nn.Module):
    def __init__(
        self,
        *,
        num_latents: int,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        super().__init__()
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.post_kl = nn.Linear(embed_dim, width)

        self.transformer = Transformer(
            n_ctx=num_latents,
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.scale_factor = scale_factor
        self.latent_shape = (num_latents, embed_dim)

    def forward(self, latents):
        latents = self.post_kl(latents)
        latents = self.transformer(latents)
        return latents

    @torch.no_grad()
    def latents2mesh(
        self,
        latents: torch.FloatTensor,
        bounds: Union[Tuple[float], List[float], float] = 1.1,
        octree_depth: int = 7,
        num_chunks: int = 10000,
        mc_level: float = -1 / 512,
        octree_resolution: int = None,
        mc_algo: str = 'dmc',
    ):
        device = latents.device

        # 1. generate query points
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_depth=octree_depth,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        xyz_samples = torch.FloatTensor(xyz_samples)

        # 2. latents to 3d volume
        batch_logits = []
        batch_size = latents.shape[0]
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
                          desc=f"MC Level {mc_level} Implicit Function:"):
            queries = xyz_samples[start: start + num_chunks, :].to(device)
            queries = queries.half()
            batch_queries = repeat(queries, "p c -> b p c", b=batch_size)

            logits = self.geo_decoder(batch_queries.to(latents.dtype), latents)
            if mc_level == -1:
                mc_level = 0
                logits = torch.sigmoid(logits) * 2 - 1
                print(f'Training with soft labels, inference with sigmoid and marching cubes level 0.')
            batch_logits.append(logits)
        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float()

        # 3. extract surface
        outputs = []
        for i in range(batch_size):
            try:
                if mc_algo == 'mc':
                    vertices, faces, normals, _ = measure.marching_cubes(
                        grid_logits[i].cpu().numpy(),
                        mc_level,
                        method="lewiner"
                    )
                    vertices = vertices / grid_size * bbox_size + bbox_min
                elif mc_algo == 'dmc':
                    if not hasattr(self, 'dmc'):
                        try:
                            from diso import DiffDMC
                        except ImportError:
                            raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
                        self.dmc = DiffDMC(dtype=torch.float32).to(device)
                    octree_resolution = 2 ** octree_depth if octree_resolution is None else octree_resolution
                    sdf = -grid_logits[i] / octree_resolution
                    verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
                    verts = center_vertices(verts)
                    vertices = verts.detach().cpu().numpy()
                    faces = faces.detach().cpu().numpy()[:, ::-1]
                else:
                    raise ValueError(f"mc_algo {mc_algo} not supported.")

                outputs.append(
                    Latent2MeshOutput(
                        mesh_v=vertices.astype(np.float32),
                        mesh_f=np.ascontiguousarray(faces)
                    )
                )

            except ValueError:
                outputs.append(None)
            except RuntimeError:
                outputs.append(None)

        return outputs
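
# Example (editor's illustrative sketch, not part of the original file; the widths,
# depths and latent sizes below are placeholders, not the released checkpoint's
# configuration): decode a batch of latents and extract meshes with the plain
# marching-cubes backend so no extra dependency is needed. Each output entry is a
# Latent2MeshOutput carrying mesh_v / mesh_f arrays, or None if no surface crosses
# the chosen level:
#
#   >>> vae = ShapeVAE(num_latents=256, embed_dim=64, width=768, heads=12,
#   ...                num_decoder_layers=16)
#   >>> z = torch.randn(1, 256, 64)     # [batch, num_latents, embed_dim]
#   >>> h = vae(z)                      # post_kl projection + transformer decoding
#   >>> meshes = vae.latents2mesh(h, octree_depth=6, mc_algo='mc')
#   >>> len(meshes)                     # one entry per batch item
#   1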