File size: 9,074 Bytes
ada4b81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import trimesh
import numpy as np
from .data_utils import discretize, undiscretize


def patchified_mesh(mesh: "trimesh.Trimesh", special_token = -2, fix_orient=True):
    '''
    Split the faces of a mesh into fan-shaped patches and flatten them into
    one coordinate sequence.

    Each patch is emitted as: the coordinates of the patch center, the
    coordinates of the surrounding vertices in walking order, then one row
    of ``[special_token] * 3`` that terminates the patch.

    Args:
        mesh: object exposing ``faces``, ``vertices``, ``vertex_faces``
            (rows padded with -1) and ``vertex_degree``. ``vertex_degree``
            is used here as the number of not-yet-emitted faces touching
            each vertex. The annotation is a string so that defining the
            function does not require evaluating ``trimesh.Trimesh``.
        special_token: value used to fill the patch-terminator row.
        fix_orient: when True, the walking order of a patch is reversed if
            needed so that (center, first, second) matches the winding of
            the stored face, i.e. the normal would be correct. This may
            increase the difficulty for learning.

    Returns:
        ``np.array`` built from the emitted rows (one row per emitted
        vertex or terminator).

    Raises:
        AssertionError: if the per-vertex face counters do not all reach
            zero after every face has been visited.
    '''
    faces = mesh.faces
    vertices = mesh.vertices
    vertex_faces = mesh.vertex_faces
    num_faces = len(faces)

    sequence = []
    unvisited = np.full(num_faces, True)
    degrees = mesh.vertex_degree.copy()

    # winding lookup: the three cyclic rotations of a face keep its
    # orientation (True), the three reversed rotations flip it (False)
    face_orient = {}
    if fix_orient:
        for face in faces:
            v0, v1, v2 = face[0], face[1], face[2]
            face_orient[(v0, v1, v2)] = True
            face_orient[(v1, v2, v0)] = True
            face_orient[(v2, v0, v1)] = True
            face_orient[(v2, v1, v0)] = False
            face_orient[(v1, v0, v2)] = False
            face_orient[(v0, v2, v1)] = False

    # faces never become unvisited again, so the first unvisited face can
    # be tracked with a forward-only cursor instead of re-filtering the
    # whole face array for every patch
    cursor = 0
    while True:
        while cursor < num_faces and not unvisited[cursor]:
            cursor += 1
        if cursor == num_faces:
            break

        # select the patch center: the vertex of the first unvisited face
        # with the most remaining faces
        cur_face = faces[cursor]
        max_deg_vertex = cur_face[np.argmax(degrees[cur_face])]

        # collect every unvisited face around the center as
        # (low vertex, high vertex, face index), sorted for determinism
        selected_faces = []
        for face_idx in vertex_faces[max_deg_vertex]:
            if face_idx != -1 and unvisited[face_idx]:
                u, v = sorted(vertex for vertex in faces[face_idx] if vertex != max_deg_vertex)
                selected_faces.append((u, v, face_idx))
        selected_faces.sort()

        # select the start vertex: a vertex that appears only once is an
        # end of an open fan (take the lowest such index), otherwise the
        # fan is closed and the lowest index is used
        cnt = {}
        for u, v, _ in selected_faces:
            cnt[u] = cnt.get(u, 0) + 1
            cnt[v] = cnt.get(v, 0) + 1
        starts = [vertex for vertex, num in cnt.items() if num == 1]
        start_idx = min(starts) if starts else selected_faces[0][0]

        # walk around the center, consuming one face per step
        face_patch = set()
        res = [start_idx]
        while len(res) <= len(selected_faces):
            vertex = res[-1]
            for u_i, v_i, face_idx_i in selected_faces:
                if face_idx_i not in face_patch and vertex in (u_i, v_i):
                    res.append(v_i if vertex == u_i else u_i)
                    face_patch.add(face_idx_i)
                    break

            # nothing was appended: the walk cannot be extended
            if res[-1] == vertex:
                break

        if fix_orient and len(res) >= 2 and not face_orient[(max_deg_vertex, res[0], res[1])]:
            res = res[::-1]

        # reduce the degree of related vertices and mark the visited faces;
        # the center keeps only the faces this walk did not reach
        degrees[max_deg_vertex] = len(selected_faces) - len(res) + 1
        last_pos = len(res) - 1
        for pos_idx, vertex in enumerate(res):
            # the two ends of the walk touch one consumed face, inner
            # vertices touch two
            degrees[vertex] -= 1 if pos_idx in (0, last_pos) else 2
        for face_idx in face_patch:
            unvisited[face_idx] = False

        sequence.append(vertices[max_deg_vertex])
        sequence.extend(vertices[vertex_idx] for vertex_idx in res)
        sequence.append([special_token] * 3)

    assert degrees.sum() == 0, 'All degrees should be zero'

    return np.array(sequence)



def get_block_representation(
        sequence, 
        block_size=8, 
        offset_size=16, 
        block_compressed=True, 
        special_token=-2, 
        use_special_block=True
    ):
    '''
    convert coordinates from Cartesian system to block indexes.

    Args:
        sequence: 2-D array of rows with three entries each; rows that
            contain ``special_token`` are patch terminators, all other rows
            are vertex coordinates. The array is NOT modified.
        block_size: number of blocks per axis.
        offset_size: number of cells per block along each axis.
        block_compressed: when True, a vertex that falls in the same block
            as the current one is emitted as its offset token only.
        special_token: marker value of the terminator rows.
        use_special_block: when True, terminators are not emitted; instead
            the block token of the first vertex after a terminator is
            shifted by ``block_size**3 + offset_size**3``.

    Returns:
        1-D ``np.int64`` array of tokens. Block tokens lie in
        ``[0, block_size**3)`` and offset tokens are shifted up by
        ``block_size**3``.

    NOTE(review): ``discretize`` is assumed to return integer grid
    coordinates in ``[0, block_size * offset_size)`` per axis -- confirm
    against data_utils.
    '''
    special_block_base = block_size**3 + offset_size**3
    sequence = np.asarray(sequence)

    # prepare coordinates: a row is a vertex only if none of its entries
    # equals the special token
    sp_mask = np.all(sequence != special_token, axis=1)
    coords = discretize(sequence[sp_mask].reshape(-1, 3))

    # convert [x, y, z] to [block_id, offset_id]
    block_xyz = coords // offset_size
    block_id = block_xyz[:, 0] * block_size**2 + block_xyz[:, 1] * block_size + block_xyz[:, 2]
    offset_xyz = coords % offset_size
    offset_id = offset_xyz[:, 0] * offset_size**2 + offset_xyz[:, 1] * offset_size + offset_xyz[:, 2]
    # offset tokens live directly after the block tokens
    offset_id = offset_id + block_size**3
    block_coords = np.concatenate([block_id[..., None], offset_id[..., None]], axis=-1).astype(np.int64)

    # work on a private copy of the first two columns: slicing alone gives
    # a view, and writing through it would overwrite the caller's array
    pairs = sequence[:, :2].copy()
    pairs[sp_mask] = block_coords

    # convert to codes
    cur_block_id = pairs[0, 0]
    codes = [cur_block_id]
    for block, offset in pairs:
        if block == special_token:
            # patch terminator
            if not use_special_block:
                codes.append(special_token)
            cur_block_id = special_token

        elif block == cur_block_id:
            if block_compressed:
                codes.append(offset)
            else:
                codes.extend([block, offset])

        else:
            # block changed; right after a terminator the block token is
            # shifted into the special range to mark the patch start
            if use_special_block and cur_block_id == special_token:
                new_block = block + special_block_base
            else:
                new_block = block
            codes.extend([new_block, offset])
            cur_block_id = new_block

    codes = np.array(codes).astype(np.int64)

    return codes.flatten()


def BPT_serialize(mesh: trimesh.Trimesh):
    '''
    serialize mesh with BPT: patchify the faces, then turn the patch
    coordinates into block-wise tokens.
    '''
    patch_marker = -2

    # 1. patchify faces into patches
    patch_sequence = patchified_mesh(mesh, special_token=patch_marker)

    # 2. convert coordinates to block-wise indexes
    return get_block_representation(
        patch_sequence,
        block_size=8,
        offset_size=16,
        block_compressed=True,
        special_token=patch_marker,
        use_special_block=True,
    )


def decode_block(sequence, compressed=True, block_size=8, offset_size=16):
    '''
    Decode block/offset tokens back to coordinates.

    Args:
        sequence: tokens to decode. With ``compressed=True`` this is a flat
            run of tokens: a token in ``[0, block_size**3)`` selects the
            current block, a token in
            ``[block_size**3, block_size**3 + offset_size**3)`` is an
            offset that is paired with the current block (block 0 until a
            block token is seen). Any other token is reported with a
            printed warning and skipped. With ``compressed=False`` the
            input is split in two halves along its last axis into block
            ids and offset tokens -- NOTE(review): presumably an (n, 2)
            array of pairs is expected here; confirm against callers.
            The input is NOT modified.
        compressed: whether ``sequence`` uses the compressed layout.
        block_size: number of blocks per axis.
        offset_size: number of cells per block along each axis.

    Returns:
        Whatever ``undiscretize`` returns for the decoded integer grid
        coordinates (three columns: block digit * offset_size + offset
        digit for each axis).
    '''
    num_blocks = block_size**3
    num_offsets = offset_size**3

    # decode from compressed representation
    if compressed:
        res = []
        res_block = 0
        for token_id, token in enumerate(sequence):
            if num_blocks + num_offsets > token >= num_blocks:
                res.append([res_block, token])
            elif num_blocks > token >= 0:
                res_block = token
            else:
                print('[Warning] too large offset idx!', token_id, sequence[token_id])
        sequence = np.array(res)

    block_id, offset_id = np.array_split(sequence, 2, axis=-1)

    # from hash representation to xyz
    # np.array_split hands back views, so everything below builds new
    # arrays instead of using in-place operators that would write into
    # the caller's data when compressed=False
    offset_id = offset_id - num_blocks
    coords = []
    for i in [2, 1, 0]:
        block_base = block_size**i
        offset_base = offset_size**i
        # most significant digit first: x, then y, then z
        coords.append((block_id // block_base) * offset_size + (offset_id // offset_base))
        block_id = block_id % block_base
        offset_id = offset_id % offset_base

    coords = np.concatenate(coords, axis=-1) # (nf 3)

    # back to continuous space
    coords = undiscretize(coords)

    return coords


def BPT_deserialize(sequence, block_size=8, offset_size=16, compressed=True, special_token=-2, use_special_block=True):
    '''
    decode codes back to coordinates.

    The token stream is cut into patches, each patch is decoded with
    ``decode_block`` and expanded into a triangle fan (center, j, j+1).

    Args:
        sequence: 1-D token sequence. It is copied first, so the caller's
            data is left untouched and repeated calls on the same input
            give the same result.
        block_size: number of blocks per axis.
        offset_size: number of cells per block along each axis.
        compressed: forwarded to ``decode_block``.
        special_token: patch terminator, used only when
            ``use_special_block`` is False.
        use_special_block: when True, a patch starts at a token inside
            ``[base, base + block_size**3)`` with
            ``base = block_size**3 + offset_size**3``.

    Returns:
        Array of stacked triangle corners, three consecutive rows per face.

    Raises:
        ValueError: from ``np.concatenate`` when no triangle was decoded.
    '''
    special_block_base = block_size**3 + offset_size**3
    special_block_end = special_block_base + block_size**3

    # private copy: the special-block marker is stripped in place below and
    # must not leak into the caller's array (slices of an ndarray are views)
    sequence = np.array(sequence)

    last = len(sequence) - 1
    start_idx = 0
    vertices = []
    for i in range(len(sequence)):
        sub_seq = []
        if not use_special_block:
            # a patch ends at a terminator token (or at the last token)
            if sequence[i] == special_token or i == last:
                sub_seq = decode_block(
                    sequence[start_idx:i],
                    compressed=compressed, block_size=block_size, offset_size=offset_size,
                )
                start_idx = i + 1

        elif special_block_base <= sequence[i] < special_block_end or i == last:
            # a special block token opens the next patch, so the tokens
            # collected so far form a complete patch; at the last token the
            # remainder (including that token) is flushed
            if i != 0:
                sub_seq = sequence[start_idx:i] if i != last else sequence[start_idx: i+1]
                if special_block_base <= sub_seq[0] < special_block_end:
                    # turn the patch-start marker back into a plain block id
                    sub_seq[0] -= special_block_base
                sub_seq = decode_block(
                    sub_seq,
                    compressed=compressed, block_size=block_size, offset_size=offset_size,
                )
                start_idx = i

        if len(sub_seq):
            # first decoded vertex is the fan center
            center, ring = sub_seq[0], sub_seq[1:]
            for j in range(len(ring) - 1):
                vertices.extend([center.reshape(1, 3), ring[j].reshape(1, 3), ring[j+1].reshape(1, 3)])

    # (nf, 3)
    return np.concatenate(vertices, axis=0)


if __name__ == '__main__':
    # a simple demo for serialize and deserialize mesh with bpt
    # The module header imports data_utils package-relatively, so this file
    # can only be executed as a module of its package
    # (python -m <package>.<module>); use the same relative import here so
    # both imports resolve in that mode.
    from .data_utils import load_process_mesh, to_mesh
    import torch

    # load the ground-truth mesh and export it for comparison
    mesh = load_process_mesh('/path/to/your/mesh', quantization_bits=7)
    mesh['faces'] = np.array(mesh['faces'])
    mesh = to_mesh(mesh['vertices'], mesh['faces'], transpose=True)
    mesh.export('gt.obj')

    # round trip: mesh -> tokens -> triangle corner coordinates
    codes = BPT_serialize(mesh)
    coordinates = BPT_deserialize(codes)

    # every three consecutive rows form one face (1-based vertex indices)
    faces = torch.arange(1, len(coordinates) + 1).view(-1, 3)
    mesh = to_mesh(coordinates, faces, transpose=False, post_process=False)
    mesh.export('reconstructed.obj')