update to demo.py
Browse files
demo.py
CHANGED
@@ -33,6 +33,10 @@ class VideoModel(nn.Module):
|
|
33 |
super(VideoModel, self).__init__()
|
34 |
self.cfg = load_cfg(config)
|
35 |
self.model = self.build_model()
|
|
|
|
|
|
|
|
|
36 |
self.templates = ['{}']
|
37 |
self.dataset = self.cfg['data']['dataset']
|
38 |
self.eval()
|
@@ -156,7 +160,7 @@ class VideoCLSModel(VideoModel):
|
|
156 |
truncation=True,
|
157 |
max_length=self.model_cfg.max_txt_l.video,
|
158 |
return_tensors="pt",
|
159 |
-
)
|
160 |
_, class_embeddings = self.model.encode_text(embeddings)
|
161 |
return class_embeddings
|
162 |
|
@@ -170,9 +174,7 @@ class VideoCLSModel(VideoModel):
|
|
170 |
images = values[0]
|
171 |
target = values[1]
|
172 |
|
173 |
-
|
174 |
-
images = images.cuda(non_blocking=True)
|
175 |
-
target = target.cuda(non_blocking=True)
|
176 |
|
177 |
# encode images
|
178 |
images = rearrange(images, 'b c k h w -> b k c h w')
|
@@ -190,7 +192,7 @@ class VideoCLSModel(VideoModel):
|
|
190 |
similarity = self.model.get_sim(image_features, self.text_features)[0]
|
191 |
|
192 |
all_outputs.append(similarity.cpu())
|
193 |
-
all_targets.append(target
|
194 |
|
195 |
all_outputs = torch.cat(all_outputs)
|
196 |
all_targets = torch.cat(all_targets)
|
|
|
33 |
super(VideoModel, self).__init__()
|
34 |
self.cfg = load_cfg(config)
|
35 |
self.model = self.build_model()
|
36 |
+
use_gpu = torch.cuda.is_available()
|
37 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
38 |
+
if use_gpu:
|
39 |
+
self.model = self.model.to(self.device)
|
40 |
self.templates = ['{}']
|
41 |
self.dataset = self.cfg['data']['dataset']
|
42 |
self.eval()
|
|
|
160 |
truncation=True,
|
161 |
max_length=self.model_cfg.max_txt_l.video,
|
162 |
return_tensors="pt",
|
163 |
+
).to(self.device)
|
164 |
_, class_embeddings = self.model.encode_text(embeddings)
|
165 |
return class_embeddings
|
166 |
|
|
|
174 |
images = values[0]
|
175 |
target = values[1]
|
176 |
|
177 |
+
images = images.to(self.device)
|
|
|
|
|
178 |
|
179 |
# encode images
|
180 |
images = rearrange(images, 'b c k h w -> b k c h w')
|
|
|
192 |
similarity = self.model.get_sim(image_features, self.text_features)[0]
|
193 |
|
194 |
all_outputs.append(similarity.cpu())
|
195 |
+
all_targets.append(target)
|
196 |
|
197 |
all_outputs = torch.cat(all_outputs)
|
198 |
all_targets = torch.cat(all_targets)
|