Spaces:
Runtime error
Runtime error
Fix device error when using cuda (#4)
Browse files- Fix device error when using cuda (030c843a1bb758298a3f0bc6f2564a26aaff878e)
Co-authored-by: Ma Jinyu <[email protected]>
- models/tag2text.py +2 -3
models/tag2text.py
CHANGED
@@ -152,8 +152,7 @@ class RAM(nn.Module):
|
|
152 |
self.class_threshold[key] = value
|
153 |
|
154 |
def load_tag_list(self, tag_list_file):
|
155 |
-
with open(tag_list_file, 'r', encoding="
|
156 |
-
# with open(tag_list_file, 'r') as f:
|
157 |
tag_list = f.read().splitlines()
|
158 |
tag_list = np.array(tag_list)
|
159 |
return tag_list
|
@@ -362,7 +361,7 @@ class Tag2Text_Caption(nn.Module):
|
|
362 |
logits = self.fc(tagging_embed[0])
|
363 |
|
364 |
targets = torch.where(
|
365 |
-
torch.sigmoid(logits) > self.class_threshold,
|
366 |
torch.tensor(1.0).to(image.device),
|
367 |
torch.zeros(self.num_class).to(image.device))
|
368 |
|
|
|
152 |
self.class_threshold[key] = value
|
153 |
|
154 |
def load_tag_list(self, tag_list_file):
|
155 |
+
with open(tag_list_file, 'r', encoding="utf8") as f:
|
|
|
156 |
tag_list = f.read().splitlines()
|
157 |
tag_list = np.array(tag_list)
|
158 |
return tag_list
|
|
|
361 |
logits = self.fc(tagging_embed[0])
|
362 |
|
363 |
targets = torch.where(
|
364 |
+
torch.sigmoid(logits) > self.class_threshold.to(image.device),
|
365 |
torch.tensor(1.0).to(image.device),
|
366 |
torch.zeros(self.num_class).to(image.device))
|
367 |
|