Spaces:
Runtime error
Runtime error
Update models/tag2text.py
Browse files- models/tag2text.py +3 -0
models/tag2text.py
CHANGED
@@ -25,6 +25,8 @@ import numpy as np
|
|
25 |
def read_json(rpath):
|
26 |
with open(rpath, 'r') as f:
|
27 |
return json.load(f)
|
|
|
|
|
28 |
|
29 |
class Tag2Text_Caption(nn.Module):
|
30 |
def __init__(self,
|
@@ -132,6 +134,7 @@ class Tag2Text_Caption(nn.Module):
|
|
132 |
targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
|
133 |
|
134 |
tag = targets.cpu().numpy()
|
|
|
135 |
bs = image.size(0)
|
136 |
tag_input = []
|
137 |
for b in range(bs):
|
|
|
25 |
def read_json(rpath):
|
26 |
with open(rpath, 'r') as f:
|
27 |
return json.load(f)
|
28 |
+
|
29 |
+
delete_tag_index = [135]
|
30 |
|
31 |
class Tag2Text_Caption(nn.Module):
|
32 |
def __init__(self,
|
|
|
134 |
targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
|
135 |
|
136 |
tag = targets.cpu().numpy()
|
137 |
+
tag[:,delete_tag_index] = 0
|
138 |
bs = image.size(0)
|
139 |
tag_input = []
|
140 |
for b in range(bs):
|