Update PyTorch example in README.md
Browse filesThe previous pytorch example was incomplete (did not show how to process the outputs) and was also not using the specific question answering version of distilbert.
README.md
CHANGED
@@ -64,10 +64,10 @@ Answer: 'SQuAD dataset', score: 0.4704, start: 147, end: 160
|
|
64 |
Here is how to use this model in PyTorch:
|
65 |
|
66 |
```python
|
67 |
-
from transformers import DistilBertTokenizer,
|
68 |
import torch
|
69 |
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
|
70 |
-
model =
|
71 |
|
72 |
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
73 |
|
@@ -75,7 +75,11 @@ inputs = tokenizer(question, text, return_tensors="pt")
|
|
75 |
with torch.no_grad():
|
76 |
outputs = model(**inputs)
|
77 |
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
```
|
80 |
|
81 |
And in TensorFlow:
|
|
|
64 |
Here is how to use this model in PyTorch:
|
65 |
|
66 |
```python
|
67 |
+
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
|
68 |
import torch
|
69 |
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
|
70 |
+
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad')
|
71 |
|
72 |
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
73 |
|
|
|
75 |
with torch.no_grad():
|
76 |
outputs = model(**inputs)
|
77 |
|
78 |
+
answer_start_index = torch.argmax(outputs.start_logits)
|
79 |
+
answer_end_index = torch.argmax(outputs.end_logits)
|
80 |
+
|
81 |
+
predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
|
82 |
+
tokenizer.decode(predict_answer_tokens)
|
83 |
```
|
84 |
|
85 |
And in TensorFlow:
|