davidvgilmore commited on
Commit
f741e81
·
verified ·
1 Parent(s): 0b55cb7

Upload hy3dgen/texgen/custom_rasterizer/custom_rasterizer/io_glb.py with huggingface_hub

Browse files
hy3dgen/texgen/custom_rasterizer/custom_rasterizer/io_glb.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Open Source Model Licensed under the Apache License Version 2.0
2
+ # and Other Licenses of the Third-Party Components therein:
3
+ # The below Model in this distribution may have been modified by THL A29 Limited
4
+ # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
+
6
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
+ # The below software and/or models in this distribution may have been
8
+ # modified by THL A29 Limited ("Tencent Modifications").
9
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
+
11
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
+ # except for the third-party components listed below.
13
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
+ # in the repsective licenses of these third-party components.
15
+ # Users must comply with all terms and conditions of original licenses of these third-party
16
+ # components and must ensure that the usage of the third party components adheres to
17
+ # all relevant laws and regulations.
18
+
19
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
20
+ # their software and algorithms, including trained model weights, parameters (including
21
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
23
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
+
25
+ import base64
26
+ import io
27
+ import os
28
+
29
+ import numpy as np
30
+ from PIL import Image as PILImage
31
+ from pygltflib import GLTF2
32
+ from scipy.spatial.transform import Rotation as R
33
+
34
+
35
+ # Function to extract buffer data
36
+ def get_buffer_data(gltf, buffer_view):
37
+ buffer = gltf.buffers[buffer_view.buffer]
38
+ buffer_data = gltf.get_data_from_buffer_uri(buffer.uri)
39
+ byte_offset = buffer_view.byteOffset if buffer_view.byteOffset else 0
40
+ byte_length = buffer_view.byteLength
41
+ return buffer_data[byte_offset:byte_offset + byte_length]
42
+
43
+
44
+ # Function to extract attribute data
45
+ def get_attribute_data(gltf, accessor_index):
46
+ accessor = gltf.accessors[accessor_index]
47
+ buffer_view = gltf.bufferViews[accessor.bufferView]
48
+ buffer_data = get_buffer_data(gltf, buffer_view)
49
+
50
+ comptype = {5120: np.int8, 5121: np.uint8, 5122: np.int16, 5123: np.uint16, 5125: np.uint32, 5126: np.float32}
51
+ dtype = comptype[accessor.componentType]
52
+
53
+ t2n = {'SCALAR': 1, 'VEC2': 2, 'VEC3': 3, 'VEC4': 4, 'MAT2': 4, 'MAT3': 9, 'MAT4': 16}
54
+ num_components = t2n[accessor.type]
55
+
56
+ # Calculate the correct slice of data
57
+ byte_offset = accessor.byteOffset if accessor.byteOffset else 0
58
+ byte_stride = buffer_view.byteStride if buffer_view.byteStride else num_components * np.dtype(dtype).itemsize
59
+ count = accessor.count
60
+
61
+ # Extract the attribute data
62
+ attribute_data = np.zeros((count, num_components), dtype=dtype)
63
+ for i in range(count):
64
+ start = byte_offset + i * byte_stride
65
+ end = start + num_components * np.dtype(dtype).itemsize
66
+ attribute_data[i] = np.frombuffer(buffer_data[start:end], dtype=dtype)
67
+
68
+ return attribute_data
69
+
70
+
71
+ # Function to extract image data
72
+ def get_image_data(gltf, image, folder):
73
+ if image.uri:
74
+ if image.uri.startswith('data:'):
75
+ # Data URI
76
+ header, encoded = image.uri.split(',', 1)
77
+ data = base64.b64decode(encoded)
78
+ else:
79
+ # External file
80
+ fn = image.uri
81
+ if not os.path.isabs(fn):
82
+ fn = folder + '/' + fn
83
+ with open(fn, 'rb') as f:
84
+ data = f.read()
85
+ else:
86
+ buffer_view = gltf.bufferViews[image.bufferView]
87
+ data = get_buffer_data(gltf, buffer_view)
88
+ return data
89
+
90
+
91
+ # Function to convert triangle strip to triangles
92
+ def convert_triangle_strip_to_triangles(indices):
93
+ triangles = []
94
+ for i in range(len(indices) - 2):
95
+ if i % 2 == 0:
96
+ triangles.append([indices[i], indices[i + 1], indices[i + 2]])
97
+ else:
98
+ triangles.append([indices[i], indices[i + 2], indices[i + 1]])
99
+ return np.array(triangles).reshape(-1, 3)
100
+
101
+
102
+ # Function to convert triangle fan to triangles
103
+ def convert_triangle_fan_to_triangles(indices):
104
+ triangles = []
105
+ for i in range(1, len(indices) - 1):
106
+ triangles.append([indices[0], indices[i], indices[i + 1]])
107
+ return np.array(triangles).reshape(-1, 3)
108
+
109
+
110
+ # Function to get the transformation matrix from a node
111
+ def get_node_transform(node):
112
+ if node.matrix:
113
+ return np.array(node.matrix).reshape(4, 4).T
114
+ else:
115
+ T = np.eye(4)
116
+ if node.translation:
117
+ T[:3, 3] = node.translation
118
+ if node.rotation:
119
+ R_mat = R.from_quat(node.rotation).as_matrix()
120
+ T[:3, :3] = R_mat
121
+ if node.scale:
122
+ S = np.diag(node.scale + [1])
123
+ T = T @ S
124
+ return T
125
+
126
+
127
+ def get_world_transform(gltf, node_index, parents, world_transforms):
128
+ if parents[node_index] == -2:
129
+ return world_transforms[node_index]
130
+
131
+ node = gltf.nodes[node_index]
132
+ if parents[node_index] == -1:
133
+ world_transforms[node_index] = get_node_transform(node)
134
+ parents[node_index] = -2
135
+ return world_transforms[node_index]
136
+
137
+ parent_index = parents[node_index]
138
+ parent_transform = get_world_transform(gltf, parent_index, parents, world_transforms)
139
+ world_transforms[node_index] = parent_transform @ get_node_transform(node)
140
+ parents[node_index] = -2
141
+ return world_transforms[node_index]
142
+
143
+
144
+ def LoadGlb(path):
145
+ # Load the GLB file using pygltflib
146
+ gltf = GLTF2().load(path)
147
+
148
+ primitives = []
149
+ images = {}
150
+ # Iterate through the meshes in the GLB file
151
+
152
+ world_transforms = [np.identity(4) for i in range(len(gltf.nodes))]
153
+ parents = [-1 for i in range(len(gltf.nodes))]
154
+ for node_index, node in enumerate(gltf.nodes):
155
+ for idx in node.children:
156
+ parents[idx] = node_index
157
+ # for i in range(len(gltf.nodes)):
158
+ # get_world_transform(gltf, i, parents, world_transform)
159
+
160
+ for node_index, node in enumerate(gltf.nodes):
161
+ if node.mesh is not None:
162
+ world_transform = get_world_transform(gltf, node_index, parents, world_transforms)
163
+ # Iterate through the primitives in the mesh
164
+ mesh = gltf.meshes[node.mesh]
165
+ for primitive in mesh.primitives:
166
+ # Access the attributes of the primitive
167
+ attributes = primitive.attributes.__dict__
168
+ mode = primitive.mode if primitive.mode is not None else 4 # Default to TRIANGLES
169
+ result = {}
170
+ if primitive.indices is not None:
171
+ indices = get_attribute_data(gltf, primitive.indices)
172
+ if mode == 4: # TRIANGLES
173
+ face_indices = indices.reshape(-1, 3)
174
+ elif mode == 5: # TRIANGLE_STRIP
175
+ face_indices = convert_triangle_strip_to_triangles(indices)
176
+ elif mode == 6: # TRIANGLE_FAN
177
+ face_indices = convert_triangle_fan_to_triangles(indices)
178
+ else:
179
+ continue
180
+ result['F'] = face_indices
181
+
182
+ # Extract vertex positions
183
+ if 'POSITION' in attributes and attributes['POSITION'] is not None:
184
+ positions = get_attribute_data(gltf, attributes['POSITION'])
185
+ # Apply the world transformation to the positions
186
+ positions_homogeneous = np.hstack([positions, np.ones((positions.shape[0], 1))])
187
+ transformed_positions = (world_transform @ positions_homogeneous.T).T[:, :3]
188
+ result['V'] = transformed_positions
189
+
190
+ # Extract vertex colors
191
+ if 'COLOR_0' in attributes and attributes['COLOR_0'] is not None:
192
+ colors = get_attribute_data(gltf, attributes['COLOR_0'])
193
+ if colors.shape[-1] > 3:
194
+ colors = colors[..., :3]
195
+ result['VC'] = colors
196
+
197
+ # Extract UVs
198
+ if 'TEXCOORD_0' in attributes and not attributes['TEXCOORD_0'] is None:
199
+ uvs = get_attribute_data(gltf, attributes['TEXCOORD_0'])
200
+ result['UV'] = uvs
201
+
202
+ if primitive.material is not None:
203
+ material = gltf.materials[primitive.material]
204
+ if material.pbrMetallicRoughness is not None and material.pbrMetallicRoughness.baseColorTexture is not None:
205
+ texture_index = material.pbrMetallicRoughness.baseColorTexture.index
206
+ texture = gltf.textures[texture_index]
207
+ image_index = texture.source
208
+ if not image_index in images:
209
+ image = gltf.images[image_index]
210
+ image_data = get_image_data(gltf, image, os.path.dirname(path))
211
+ pil_image = PILImage.open(io.BytesIO(image_data))
212
+ if pil_image.mode != 'RGB':
213
+ pil_image = pil_image.convert('RGB')
214
+ images[image_index] = pil_image
215
+ result['TEX'] = image_index
216
+ elif material.emissiveTexture is not None:
217
+ texture_index = material.emissiveTexture.index
218
+ texture = gltf.textures[texture_index]
219
+ image_index = texture.source
220
+ if not image_index in images:
221
+ image = gltf.images[image_index]
222
+ image_data = get_image_data(gltf, image, os.path.dirname(path))
223
+ pil_image = PILImage.open(io.BytesIO(image_data))
224
+ if pil_image.mode != 'RGB':
225
+ pil_image = pil_image.convert('RGB')
226
+ images[image_index] = pil_image
227
+ result['TEX'] = image_index
228
+ else:
229
+ if material.pbrMetallicRoughness is not None:
230
+ base_color = material.pbrMetallicRoughness.baseColorFactor
231
+ else:
232
+ base_color = np.array([0.8, 0.8, 0.8], dtype=np.float32)
233
+ result['MC'] = base_color
234
+
235
+ primitives.append(result)
236
+
237
+ return primitives, images
238
+
239
+
240
+ def RotatePrimitives(primitives, transform):
241
+ for i in range(len(primitives)):
242
+ if 'V' in primitives[i]:
243
+ primitives[i]['V'] = primitives[i]['V'] @ transform.T
244
+
245
+
246
+ if __name__ == '__main__':
247
+ path = 'data/test.glb'
248
+ LoadGlb(path)