Add code example
README.md (changed)
Vision-and-Language Transformer (ViLT) model pre-trained on GCC+SBU+COCO+VG (200k steps).
Disclaimer: The team releasing ViLT did not write a model card for this model, so this model card has been written by the Hugging Face team.
## Intended uses & limitations
You can use the raw model for masked language modeling given an image and a piece of text with [MASK] tokens.
### How to use
Here is how to use this model in PyTorch:
```
from transformers import ViltProcessor, ViltForMaskedLM
from PIL import Image
import requests
import re
import torch

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "a bunch of [MASK] laying on a [MASK]."

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")

# prepare inputs
encoding = processor(image, text, return_tensors="pt")
pixel_values = encoding.pixel_values

# forward pass
outputs = model(**encoding)

tl = len(re.findall(r"\[MASK\]", text))
inferred_token = [text]

# gradually fill in the [MASK] tokens, one by one
with torch.no_grad():
    for i in range(tl):
        encoded = processor.tokenizer(inferred_token)
        input_ids = torch.tensor(encoded.input_ids)
        encoded = encoded["input_ids"][0][1:-1]
        outputs = model(input_ids=input_ids, pixel_values=pixel_values)
        mlm_logits = outputs.logits[0]  # shape (seq_len, vocab_size)
        # only take into account text features (minus CLS and SEP token)
        mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
        mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
        # only keep predictions at [MASK] positions
        mlm_values[torch.tensor(encoded) != processor.tokenizer.mask_token_id] = 0
        # fill in the [MASK] the model is most confident about
        select = mlm_values.argmax().item()
        encoded[select] = mlm_ids[select].item()
        inferred_token = [processor.decode(encoded)]

# decode the final filled-in text
encoded = processor.tokenizer(inferred_token)
print(processor.decode(encoded.input_ids[0], skip_special_tokens=True))
```
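
The loop above fills in one [MASK] at a time, so each newly predicted word can condition the next one. If a single forward pass is enough for your use case, the sketch below (an illustration, not part of the original model card) reuses the `processor`, `model`, `image` and `text` objects defined above and predicts every [MASK] independently:

```
# Single-pass sketch (for illustration): predict each [MASK] independently
# from one forward pass, reusing processor/model/image/text from above.
import torch

encoding = processor(image, text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**encoding)

# positions of the [MASK] tokens in the text input
input_ids = encoding.input_ids[0]
mask_positions = (input_ids == processor.tokenizer.mask_token_id).nonzero(as_tuple=True)[0]

# MLM logits are aligned with the text tokens: shape (text_seq_len, vocab_size)
predicted_ids = outputs.logits[0][mask_positions].argmax(dim=-1)
print(processor.tokenizer.decode(predicted_ids))
```

The iterative version above will generally produce more coherent completions, since each later prediction can take the earlier ones into account.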
## Training data