yuanze1024 committed
Commit 4c05bb3
1 Parent(s): d65e4f4

remove unused code

app.py CHANGED
@@ -1,10 +1,3 @@
- import subprocess
-
- # a workaround for gradio SDK
- subprocess.call(["pip", "install", "torch==2.1.0+cu118", "torchvision==0.16.0+cu118", "-i", "https://download.pytorch.org/whl/cu118"])
- subprocess.call(["git", "clone", "https://github.com/yuanze1024/Pointnet2_PyTorch.git"])
- subprocess.call(["pip", "install", "."], cwd="Pointnet2_PyTorch/pointnet2_ops_lib")
-
import os
import random
import gradio as gr
@@ -153,7 +146,8 @@ The *Modality List* refers to the features ensembled by the retrieval methods. A
Also, you may want to ckeck the 3D model in a 3D model viewer, in that case, you can visit [Objaverse](https://objaverse.allenai.org/explore) for exploration.""")
with gr.Row():
textual_query = gr.Textbox(label="Textual Query", autofocus=True, value="Super Mario")
- modality_list = gr.CheckboxGroup(label="Modality List", value=[],
+ modality_list = gr.CheckboxGroup(label="Modality List", value=["text", "front", "back", "left", "right", "above",
+ "below", "diag_above", "diag_below", "3D"],
choices=["text", "front", "back", "left", "right", "above",
"below", "diag_above", "diag_below", "3D"])
with gr.Row():
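
For context, the app.py change above only switches the default selection of an existing Gradio CheckboxGroup from empty to all modalities. Below is a minimal, self-contained sketch of that pattern (not part of the commit; the Blocks/Row scaffolding and launch call are illustrative assumptions, while the labels and choice values are copied from the diff):

import gradio as gr

# All retrieval modalities offered by the demo, copied from the diff above.
MODALITIES = ["text", "front", "back", "left", "right", "above",
              "below", "diag_above", "diag_below", "3D"]

with gr.Blocks() as demo:
    with gr.Row():
        textual_query = gr.Textbox(label="Textual Query", autofocus=True, value="Super Mario")
        # Passing the full choice list as `value` makes every box checked on load,
        # which is what the "+" lines above do.
        modality_list = gr.CheckboxGroup(label="Modality List",
                                         choices=MODALITIES,
                                         value=MODALITIES)

if __name__ == "__main__":
    demo.launch()
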
feature_extractors/uni3d_embedding_encoder.py CHANGED
@@ -1,319 +1,37 @@
"""
- See https://github.com/baaivision/Uni3D for source code
+ This is a modified version which only extract text embedding in HF Space.
+ See https://github.com/baaivision/Uni3D for source code.
+ Or refer to https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py for extracting all embeddings.
"""
import os
- import torch
- import torch.nn as nn
- import timm
- import numpy as np
- from pointnet2_ops import pointnet2_utils
+ import sys
+
import open_clip
+ import torch
from huggingface_hub import hf_hub_download
- import sys
+
sys.path.append('')
from feature_extractors import FeatureExtractor
from utils.tokenizer import SimpleTokenizer

- import logging
-
- def fps(data, number):
- '''
- data B N 3
- number int
- '''
- fps_idx = pointnet2_utils.furthest_point_sample(data, number)
- fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
- return fps_data
-
- # https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py
- def knn_point(nsample, xyz, new_xyz):
- """
- Input:
- nsample: max sample number in local region
- xyz: all points, [B, N, C]
- new_xyz: query points, [B, S, C]
- Return:
- group_idx: grouped points index, [B, S, nsample]
- """
- sqrdists = square_distance(new_xyz, xyz)
- _, group_idx = torch.topk(sqrdists, nsample, dim = -1, largest=False, sorted=False)
- return group_idx
-
- def square_distance(src, dst):
- """
- Calculate Euclid distance between each two points.
- src^T * dst = xn * xm + yn * ym + zn * zm;
- sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
- sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
- dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
- = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
- Input:
- src: source points, [B, N, C]
- dst: target points, [B, M, C]
- Output:
- dist: per-point square distance, [B, N, M]
- """
- B, N, _ = src.shape
- _, M, _ = dst.shape
- dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
- dist += torch.sum(src ** 2, -1).view(B, N, 1)
- dist += torch.sum(dst ** 2, -1).view(B, 1, M)
- return dist
-
-
- class PatchDropout(nn.Module):
- """
- https://arxiv.org/abs/2212.00794
- """
-
- def __init__(self, prob, exclude_first_token=True):
- super().__init__()
- assert 0 <= prob < 1.
- self.prob = prob
- self.exclude_first_token = exclude_first_token # exclude CLS token
- logging.info("patch dropout prob is {}".format(prob))
-
- def forward(self, x):
- # if not self.training or self.prob == 0.:
- # return x
-
- if self.exclude_first_token:
- cls_tokens, x = x[:, :1], x[:, 1:]
- else:
- cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
-
- batch = x.size()[0]
- num_tokens = x.size()[1]
-
- batch_indices = torch.arange(batch)
- batch_indices = batch_indices[..., None]
-
- keep_prob = 1 - self.prob
- num_patches_keep = max(1, int(num_tokens * keep_prob))
-
- rand = torch.randn(batch, num_tokens)
- patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
-
- x = x[batch_indices, patch_indices_keep]
-
- if self.exclude_first_token:
- x = torch.cat((cls_tokens, x), dim=1)
-
- return x
-
-
- class Group(nn.Module):
- def __init__(self, num_group, group_size):
- super().__init__()
- self.num_group = num_group
- self.group_size = group_size
-
- def forward(self, xyz, color):
- '''
- input: B N 3
- ---------------------------
- output: B G M 3
- center : B G 3
- '''
- batch_size, num_points, _ = xyz.shape
- # fps the centers out
- center = fps(xyz, self.num_group) # B G 3
- # knn to get the neighborhood
- # _, idx = self.knn(xyz, center) # B G M
- idx = knn_point(self.group_size, xyz, center) # B G M
- assert idx.size(1) == self.num_group
- assert idx.size(2) == self.group_size
- idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
- idx = idx + idx_base
- idx = idx.view(-1)
- neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
- neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
-
- neighborhood_color = color.view(batch_size * num_points, -1)[idx, :]
- neighborhood_color = neighborhood_color.view(batch_size, self.num_group, self.group_size, 3).contiguous()
-
- # normalize
- neighborhood = neighborhood - center.unsqueeze(2)
-
- features = torch.cat((neighborhood, neighborhood_color), dim=-1)
- return neighborhood, center, features
-
- class Encoder(nn.Module):
- def __init__(self, encoder_channel):
- super().__init__()
- self.encoder_channel = encoder_channel
- self.first_conv = nn.Sequential(
- nn.Conv1d(6, 128, 1),
- nn.BatchNorm1d(128),
- nn.ReLU(inplace=True),
- nn.Conv1d(128, 256, 1)
- )
- self.second_conv = nn.Sequential(
- nn.Conv1d(512, 512, 1),
- nn.BatchNorm1d(512),
- nn.ReLU(inplace=True),
- nn.Conv1d(512, self.encoder_channel, 1)
- )
- def forward(self, point_groups):
- '''
- point_groups : B G N 3
- -----------------
- feature_global : B G C
- '''
- bs, g, n , _ = point_groups.shape
- point_groups = point_groups.reshape(bs * g, n, 6)
- # encoder
- feature = self.first_conv(point_groups.transpose(2,1)) # BG 256 n
- feature_global = torch.max(feature,dim=2,keepdim=True)[0] # BG 256 1
- feature = torch.cat([feature_global.expand(-1,-1,n), feature], dim=1)# BG 512 n
- feature = self.second_conv(feature) # BG 1024 n
- feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
- return feature_global.reshape(bs, g, self.encoder_channel)
-
- class PointcloudEncoder(nn.Module):
- def __init__(self, point_transformer):
- # use the giant branch of uni3d
- super().__init__()
- from easydict import EasyDict
- self.trans_dim = 1408
- self.embed_dim = 1024
- self.group_size = 64
- self.num_group = 512
- # grouper
- self.group_divider = Group(num_group = self.num_group, group_size = self.group_size)
- # define the encoder
- self.encoder_dim = 512
- self.encoder = Encoder(encoder_channel = self.encoder_dim)
-
- # bridge encoder and transformer
- self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
-
- # bridge transformer and clip embedding
- self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)
- self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
- self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
-
- self.pos_embed = nn.Sequential(
- nn.Linear(3, 128),
- nn.GELU(),
- nn.Linear(128, self.trans_dim)
- )
- # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
- self.patch_dropout = PatchDropout(0.) if 0. > 0. else nn.Identity()
- self.visual = point_transformer
-
-
- def forward(self, pts, colors):
- # divide the point cloud in the same form. This is important
- _, center, features = self.group_divider(pts, colors)
-
- # encoder the input cloud patches
- group_input_tokens = self.encoder(features) # B G N
- group_input_tokens = self.encoder2trans(group_input_tokens)
- # prepare cls
- cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
- cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
- # add pos embedding
- pos = self.pos_embed(center)
- # final input
- x = torch.cat((cls_tokens, group_input_tokens), dim=1)
- pos = torch.cat((cls_pos, pos), dim=1)
- # transformer
- x = x + pos
- # x = x.half()
-
- # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
- x = self.patch_dropout(x)
-
- x = self.visual.pos_drop(x)
-
- # ModuleList not support forward
- for i, blk in enumerate(self.visual.blocks):
- x = blk(x)
- x = self.visual.norm(x[:, 0, :])
- x = self.visual.fc_norm(x)
-
- x = self.trans2embed(x)
- return x
-
- class Uni3D(nn.Module):
- def __init__(self, point_encoder):
- super().__init__()
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
- self.point_encoder = point_encoder
-
- def encode_pc(self, pc):
- xyz = pc[:,:,:3].contiguous()
- color = pc[:,:,3:].contiguous()
- pc_feat = self.point_encoder(xyz, color)
- return pc_feat
-
- def forward(self, pc, text, image):
- text_embed_all = text
- image_embed = image
- pc_embed = self.encode_pc(pc)
- return {'text_embed': text_embed_all,
- 'pc_embed': pc_embed,
- 'image_embed': image_embed,
- 'logit_scale': self.logit_scale.exp()}
-
- def get_metric_names(model):
- return ['loss', 'uni3d_loss', 'pc_image_acc', 'pc_text_acc']
-
- def create_uni3d(uni3d_path):
- # create transformer blocks for point cloud via timm
- point_transformer = timm.create_model("eva_giant_patch14_560")
-
- # create whole point cloud encoder
- point_encoder = PointcloudEncoder(point_transformer)
-
- # uni3d model
- model = Uni3D(point_encoder=point_encoder,)
-
- checkpoint = torch.load(uni3d_path, map_location='cpu')
- logging.info('loaded checkpoint {}'.format(uni3d_path))
- sd = checkpoint['module']
- if next(iter(sd.items()))[0].startswith('module'):
- sd = {k[len('module.'):]: v for k, v in sd.items()}
- model.load_state_dict(sd)
- return model

class Uni3dEmbeddingEncoder(FeatureExtractor):
def __init__(self, cache_dir, **kwargs) -> None:
bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
- # uni3d_path = os.path.join(cache_dir, "Uni3D", "modelzoo", "uni3d-g", "model.pt") # concat the subfolder as hf_hub_download will put it here
clip_path = os.path.join(cache_dir, "Uni3D", "open_clip_pytorch_model.bin")

- # if not os.path.exists(uni3d_path):
- # hf_hub_download("BAAI/Uni3D", "model.pt", subfolder="modelzoo/uni3d-g", cache_dir=cache_dir,
- # local_dir=cache_dir + os.sep + "Uni3D")
if not os.path.exists(clip_path):
hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k", "open_clip_pytorch_model.bin",
cache_dir=cache_dir, local_dir=cache_dir + os.sep + "Uni3D")

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = SimpleTokenizer(bpe_path)
- # self.model = create_uni3d(uni3d_path)
- # self.model.eval()
- # self.model.to(self.device)
self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(model_name="EVA02-E-14-plus", pretrained=clip_path)
self.clip_model.to(self.device)

- def pc_norm(self, pc):
- """ pc: NxC, return NxC """
- centroid = np.mean(pc, axis=0)
- pc = pc - centroid
- m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
- pc = pc / m
- return pc
-
@torch.no_grad()
def encode_3D(self, data):
- pass
- # pc = data.to(device=self.device, non_blocking=True)
- # pc_features = self.model.encode_pc(pc)
- # pc_features = pc_features / pc_features.norm(dim=-1, keepdim=True)
- # return pc_features.float()
+ raise NotImplementedError("For extracting 3D feature, see https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py")

@torch.no_grad()
def encode_text(self, input_text):
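
After this change the encoder keeps only the open_clip text tower (EVA02-E-14-plus) and raises NotImplementedError for 3D features. As a rough sketch of the generic open_clip text-embedding flow the retained class relies on (illustrative only: it uses open_clip's built-in tokenizer and a placeholder weights path, whereas the repository's encode_text goes through its own SimpleTokenizer and is not shown in this diff):

import torch
import open_clip

# Placeholder path standing in for the checkpoint downloaded via hf_hub_download above.
CLIP_WEIGHTS = "path/to/open_clip_pytorch_model.bin"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, _, _ = open_clip.create_model_and_transforms(model_name="EVA02-E-14-plus", pretrained=CLIP_WEIGHTS)
model = model.to(device).eval()
tokenizer = open_clip.get_tokenizer("EVA02-E-14-plus")

with torch.no_grad():
    tokens = tokenizer(["a 3D model of Super Mario"]).to(device)
    text_feat = model.encode_text(tokens)
    # Unit-normalize so embeddings can be compared with cosine similarity in retrieval.
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
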
requirements.txt CHANGED
@@ -1,4 +1,6 @@
gradio
+ torch
+ torchvision
datasets
timm
pillow