yuanze1024 committed • Commit 4c05bb3
1 Parent(s): d65e4f4

remove unused code

Files changed:
- app.py +2 -8
- feature_extractors/uni3d_embedding_encoder.py +8 -290
- requirements.txt +2 -0
app.py
CHANGED
@@ -1,10 +1,3 @@
-import subprocess
-
-# a workaround for gradio SDK
-subprocess.call(["pip", "install", "torch==2.1.0+cu118", "torchvision==0.16.0+cu118", "-i", "https://download.pytorch.org/whl/cu118"])
-subprocess.call(["git", "clone", "https://github.com/yuanze1024/Pointnet2_PyTorch.git"])
-subprocess.call(["pip", "install", "."], cwd="Pointnet2_PyTorch/pointnet2_ops_lib")
-
 import os
 import random
 import gradio as gr
@@ -153,7 +146,8 @@ The *Modality List* refers to the features ensembled by the retrieval methods. A
 Also, you may want to ckeck the 3D model in a 3D model viewer, in that case, you can visit [Objaverse](https://objaverse.allenai.org/explore) for exploration.""")
     with gr.Row():
         textual_query = gr.Textbox(label="Textual Query", autofocus=True, value="Super Mario")
-        modality_list = gr.CheckboxGroup(label="Modality List", value=[
+        modality_list = gr.CheckboxGroup(label="Modality List", value=["text", "front", "back", "left", "right", "above",
+                                                                       "below", "diag_above", "diag_below", "3D"],
                                          choices=["text", "front", "back", "left", "right", "above",
                                                   "below", "diag_above", "diag_below", "3D"])
     with gr.Row():
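
The change above passes the same modality list to both `value` and `choices`, so every modality is pre-selected when the Space loads. As a minimal standalone sketch (not part of this commit, and independent of the rest of app.py), the behavior of gr.CheckboxGroup with both arguments looks like this:

import gradio as gr

# All modalities used by the Space; passing the same list as `choices` (what can
# be ticked) and `value` (what starts ticked) pre-selects everything by default.
MODALITIES = ["text", "front", "back", "left", "right", "above",
              "below", "diag_above", "diag_below", "3D"]

with gr.Blocks() as demo:
    modality_list = gr.CheckboxGroup(label="Modality List",
                                     choices=MODALITIES,
                                     value=MODALITIES)

# demo.launch()  # uncomment to run the sketch locally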
feature_extractors/uni3d_embedding_encoder.py
CHANGED
@@ -1,319 +1,37 @@
 """
-
+This is a modified version which only extract text embedding in HF Space.
+See https://github.com/baaivision/Uni3D for source code.
+Or refer to https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py for extracting all embeddings.
 """
 import os
-import
-
-import timm
-import numpy as np
-from pointnet2_ops import pointnet2_utils
+import sys
+
 import open_clip
+import torch
 from huggingface_hub import hf_hub_download
-
+
 sys.path.append('')
 from feature_extractors import FeatureExtractor
 from utils.tokenizer import SimpleTokenizer
 
-import logging
-
-def fps(data, number):
-    '''
-        data B N 3
-        number int
-    '''
-    fps_idx = pointnet2_utils.furthest_point_sample(data, number)
-    fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
-    return fps_data
-
-# https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py
-def knn_point(nsample, xyz, new_xyz):
-    """
-    Input:
-        nsample: max sample number in local region
-        xyz: all points, [B, N, C]
-        new_xyz: query points, [B, S, C]
-    Return:
-        group_idx: grouped points index, [B, S, nsample]
-    """
-    sqrdists = square_distance(new_xyz, xyz)
-    _, group_idx = torch.topk(sqrdists, nsample, dim = -1, largest=False, sorted=False)
-    return group_idx
-
-def square_distance(src, dst):
-    """
-    Calculate Euclid distance between each two points.
-    src^T * dst = xn * xm + yn * ym + zn * zm;
-    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
-    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
-    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
-         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
-    Input:
-        src: source points, [B, N, C]
-        dst: target points, [B, M, C]
-    Output:
-        dist: per-point square distance, [B, N, M]
-    """
-    B, N, _ = src.shape
-    _, M, _ = dst.shape
-    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
-    dist += torch.sum(src ** 2, -1).view(B, N, 1)
-    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
-    return dist
-
-
-class PatchDropout(nn.Module):
-    """
-    https://arxiv.org/abs/2212.00794
-    """
-
-    def __init__(self, prob, exclude_first_token=True):
-        super().__init__()
-        assert 0 <= prob < 1.
-        self.prob = prob
-        self.exclude_first_token = exclude_first_token # exclude CLS token
-        logging.info("patch dropout prob is {}".format(prob))
-
-    def forward(self, x):
-        # if not self.training or self.prob == 0.:
-        #     return x
-
-        if self.exclude_first_token:
-            cls_tokens, x = x[:, :1], x[:, 1:]
-        else:
-            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
-
-        batch = x.size()[0]
-        num_tokens = x.size()[1]
-
-        batch_indices = torch.arange(batch)
-        batch_indices = batch_indices[..., None]
-
-        keep_prob = 1 - self.prob
-        num_patches_keep = max(1, int(num_tokens * keep_prob))
-
-        rand = torch.randn(batch, num_tokens)
-        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
-
-        x = x[batch_indices, patch_indices_keep]
-
-        if self.exclude_first_token:
-            x = torch.cat((cls_tokens, x), dim=1)
-
-        return x
-
-
-class Group(nn.Module):
-    def __init__(self, num_group, group_size):
-        super().__init__()
-        self.num_group = num_group
-        self.group_size = group_size
-
-    def forward(self, xyz, color):
-        '''
-            input: B N 3
-            ---------------------------
-            output: B G M 3
-            center : B G 3
-        '''
-        batch_size, num_points, _ = xyz.shape
-        # fps the centers out
-        center = fps(xyz, self.num_group) # B G 3
-        # knn to get the neighborhood
-        # _, idx = self.knn(xyz, center) # B G M
-        idx = knn_point(self.group_size, xyz, center) # B G M
-        assert idx.size(1) == self.num_group
-        assert idx.size(2) == self.group_size
-        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
-        idx = idx + idx_base
-        idx = idx.view(-1)
-        neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
-        neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
-
-        neighborhood_color = color.view(batch_size * num_points, -1)[idx, :]
-        neighborhood_color = neighborhood_color.view(batch_size, self.num_group, self.group_size, 3).contiguous()
-
-        # normalize
-        neighborhood = neighborhood - center.unsqueeze(2)
-
-        features = torch.cat((neighborhood, neighborhood_color), dim=-1)
-        return neighborhood, center, features
-
-class Encoder(nn.Module):
-    def __init__(self, encoder_channel):
-        super().__init__()
-        self.encoder_channel = encoder_channel
-        self.first_conv = nn.Sequential(
-            nn.Conv1d(6, 128, 1),
-            nn.BatchNorm1d(128),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(128, 256, 1)
-        )
-        self.second_conv = nn.Sequential(
-            nn.Conv1d(512, 512, 1),
-            nn.BatchNorm1d(512),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(512, self.encoder_channel, 1)
-        )
-    def forward(self, point_groups):
-        '''
-            point_groups : B G N 3
-            -----------------
-            feature_global : B G C
-        '''
-        bs, g, n , _ = point_groups.shape
-        point_groups = point_groups.reshape(bs * g, n, 6)
-        # encoder
-        feature = self.first_conv(point_groups.transpose(2,1))  # BG 256 n
-        feature_global = torch.max(feature,dim=2,keepdim=True)[0]  # BG 256 1
-        feature = torch.cat([feature_global.expand(-1,-1,n), feature], dim=1)# BG 512 n
-        feature = self.second_conv(feature) # BG 1024 n
-        feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
-        return feature_global.reshape(bs, g, self.encoder_channel)
-
-class PointcloudEncoder(nn.Module):
-    def __init__(self, point_transformer):
-        # use the giant branch of uni3d
-        super().__init__()
-        from easydict import EasyDict
-        self.trans_dim = 1408
-        self.embed_dim = 1024
-        self.group_size = 64
-        self.num_group = 512
-        # grouper
-        self.group_divider = Group(num_group = self.num_group, group_size = self.group_size)
-        # define the encoder
-        self.encoder_dim = 512
-        self.encoder = Encoder(encoder_channel = self.encoder_dim)
-
-        # bridge encoder and transformer
-        self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
-
-        # bridge transformer and clip embedding
-        self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
-        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
-
-        self.pos_embed = nn.Sequential(
-            nn.Linear(3, 128),
-            nn.GELU(),
-            nn.Linear(128, self.trans_dim)
-        )
-        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
-        self.patch_dropout = PatchDropout(0.) if 0. > 0. else nn.Identity()
-        self.visual = point_transformer
-
-
-    def forward(self, pts, colors):
-        # divide the point cloud in the same form. This is important
-        _, center, features = self.group_divider(pts, colors)
-
-        # encoder the input cloud patches
-        group_input_tokens = self.encoder(features)  # B G N
-        group_input_tokens = self.encoder2trans(group_input_tokens)
-        # prepare cls
-        cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
-        cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
-        # add pos embedding
-        pos = self.pos_embed(center)
-        # final input
-        x = torch.cat((cls_tokens, group_input_tokens), dim=1)
-        pos = torch.cat((cls_pos, pos), dim=1)
-        # transformer
-        x = x + pos
-        # x = x.half()
-
-        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
-        x = self.patch_dropout(x)
-
-        x = self.visual.pos_drop(x)
-
-        # ModuleList not support forward
-        for i, blk in enumerate(self.visual.blocks):
-            x = blk(x)
-        x = self.visual.norm(x[:, 0, :])
-        x = self.visual.fc_norm(x)
-
-        x = self.trans2embed(x)
-        return x
-
-class Uni3D(nn.Module):
-    def __init__(self, point_encoder):
-        super().__init__()
-        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
-        self.point_encoder = point_encoder
-
-    def encode_pc(self, pc):
-        xyz = pc[:,:,:3].contiguous()
-        color = pc[:,:,3:].contiguous()
-        pc_feat = self.point_encoder(xyz, color)
-        return pc_feat
-
-    def forward(self, pc, text, image):
-        text_embed_all = text
-        image_embed = image
-        pc_embed = self.encode_pc(pc)
-        return {'text_embed': text_embed_all,
-                'pc_embed': pc_embed,
-                'image_embed': image_embed,
-                'logit_scale': self.logit_scale.exp()}
-
-def get_metric_names(model):
-    return ['loss', 'uni3d_loss', 'pc_image_acc', 'pc_text_acc']
-
-def create_uni3d(uni3d_path):
-    # create transformer blocks for point cloud via timm
-    point_transformer = timm.create_model("eva_giant_patch14_560")
-
-    # create whole point cloud encoder
-    point_encoder = PointcloudEncoder(point_transformer)
-
-    # uni3d model
-    model = Uni3D(point_encoder=point_encoder,)
-
-    checkpoint = torch.load(uni3d_path, map_location='cpu')
-    logging.info('loaded checkpoint {}'.format(uni3d_path))
-    sd = checkpoint['module']
-    if next(iter(sd.items()))[0].startswith('module'):
-        sd = {k[len('module.'):]: v for k, v in sd.items()}
-    model.load_state_dict(sd)
-    return model
 
 class Uni3dEmbeddingEncoder(FeatureExtractor):
     def __init__(self, cache_dir, **kwargs) -> None:
         bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
-        # uni3d_path = os.path.join(cache_dir, "Uni3D", "modelzoo", "uni3d-g", "model.pt") # concat the subfolder as hf_hub_download will put it here
         clip_path = os.path.join(cache_dir, "Uni3D", "open_clip_pytorch_model.bin")
 
-        # if not os.path.exists(uni3d_path):
-        #     hf_hub_download("BAAI/Uni3D", "model.pt", subfolder="modelzoo/uni3d-g", cache_dir=cache_dir,
-        #                     local_dir=cache_dir + os.sep + "Uni3D")
         if not os.path.exists(clip_path):
             hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k", "open_clip_pytorch_model.bin",
                             cache_dir=cache_dir, local_dir=cache_dir + os.sep + "Uni3D")
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.tokenizer = SimpleTokenizer(bpe_path)
-        # self.model = create_uni3d(uni3d_path)
-        # self.model.eval()
-        # self.model.to(self.device)
         self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(model_name="EVA02-E-14-plus", pretrained=clip_path)
         self.clip_model.to(self.device)
 
-    def pc_norm(self, pc):
-        """ pc: NxC, return NxC """
-        centroid = np.mean(pc, axis=0)
-        pc = pc - centroid
-        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
-        pc = pc / m
-        return pc
-
     @torch.no_grad()
     def encode_3D(self, data):
-
-        # pc = data.to(device=self.device, non_blocking=True)
-        # pc_features = self.model.encode_pc(pc)
-        # pc_features = pc_features / pc_features.norm(dim=-1, keepdim=True)
-        # return pc_features.float()
+        raise NotImplementedError("For extracting 3D feature, see https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py")
 
     @torch.no_grad()
     def encode_text(self, input_text):
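
After this commit the encoder keeps only the EVA02-E-14-plus CLIP branch, so the Space can embed textual queries but raises NotImplementedError for 3D data. The encode_text body is not shown in the truncated hunk above; the following is only a hypothetical sketch of a CLIP text-encoding path with open_clip (it uses open_clip's stock tokenizer instead of the repo's SimpleTokenizer, and loads no checkpoint, so the embeddings are meaningless; the Space itself passes pretrained=clip_path):

import torch
import open_clip

device = "cuda" if torch.cuda.is_available() else "cpu"

# Randomly initialized EVA02-E-14-plus (a very large model) just to illustrate the API.
model, _, _ = open_clip.create_model_and_transforms("EVA02-E-14-plus", pretrained=None)
tokenizer = open_clip.get_tokenizer("EVA02-E-14-plus")
model = model.to(device).eval()

queries = ["Super Mario", "a red sports car"]
with torch.no_grad():
    tokens = tokenizer(queries).to(device)        # [batch, context_length]
    text_features = model.encode_text(tokens)      # [batch, embed_dim]
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)  # L2-normalize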
requirements.txt
CHANGED
@@ -1,4 +1,6 @@
 gradio
+torch
+torchvision
 datasets
 timm
 pillow
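
With torch and torchvision now listed here, the pinned cu118 wheels that the removed subprocess workaround in app.py used to install are presumably no longer needed; unpinned PyPI wheels are resolved at build time instead. A quick, illustrative sanity check after `pip install -r requirements.txt` might be:

import torch
import torchvision

# Versions come from whatever PyPI resolves; CUDA availability depends on the host image.
print(torch.__version__, torchvision.__version__, torch.cuda.is_available())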