brydon commited on
Commit
90fe6fb
1 Parent(s): 4b49e24

Update PyTorch example in README.md

Browse files

The previous pytorch example was incomplete (did not show how to process the outputs) and was also not using the specific question answering version of distilbert.

Files changed (1) hide show
  1. README.md +7 -3
README.md CHANGED
@@ -64,10 +64,10 @@ Answer: 'SQuAD dataset', score: 0.4704, start: 147, end: 160
64
  Here is how to use this model in PyTorch:
65
 
66
  ```python
67
- from transformers import DistilBertTokenizer, DistilBertModel
68
  import torch
69
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
70
- model = DistilBertModel.from_pretrained('distilbert-base-uncased-distilled-squad')
71
 
72
  question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
73
 
@@ -75,7 +75,11 @@ inputs = tokenizer(question, text, return_tensors="pt")
75
  with torch.no_grad():
76
  outputs = model(**inputs)
77
 
78
- print(outputs)
 
 
 
 
79
  ```
80
 
81
  And in TensorFlow:
 
64
  Here is how to use this model in PyTorch:
65
 
66
  ```python
67
+ from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
68
  import torch
69
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
70
+ model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad')
71
 
72
  question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
73
 
 
75
  with torch.no_grad():
76
  outputs = model(**inputs)
77
 
78
+ answer_start_index = torch.argmax(outputs.start_logits)
79
+ answer_end_index = torch.argmax(outputs.end_logits)
80
+
81
+ predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
82
+ tokenizer.decode(predict_answer_tokens)
83
  ```
84
 
85
  And in TensorFlow: