---
license: afl-3.0
language: en
tags:
- t5
datasets:
- wikipedia
---

# chunked T5 - small (cT5-small)

Github: https://github.com/mtreviso/chunked-t5

A T5 model trained with a new loss in which a special end-of-chunk token `</c>` is appended after the tokens predicted for each sentinel.
The decoder has to predict every masked span, each followed by `</c>`.
This allows much faster autoregressive generation, since the decoder can predict tokens for multiple chunks in parallel.

For example, for the input `the quick brown fox jumps over the lazy dog`:
```
encoder: the <extra_id_0> fox jumps <extra_id_1> the lazy dog

T5 decoder : <extra_id_0> quick brown <extra_id_1> over <extra_id_2>
cT5 decoder: <extra_id_0> quick brown </c> <extra_id_1> over </c> <extra_id_2>
```
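
As an illustration of how such a target could be assembled, here is a minimal sketch (the `build_ct5_target` helper is hypothetical and not part of the released code):

```python
EOC = "</c>"  # special end-of-chunk token

def build_ct5_target(chunks):
    """Join sentinel tokens and chunks, appending `</c>` after each chunk,
    and close with a final sentinel, as in the cT5 decoder target above."""
    parts = [f"<extra_id_{i}> {chunk} {EOC}" for i, chunk in enumerate(chunks)]
    parts.append(f"<extra_id_{len(chunks)}>")
    return " ".join(parts)

print(build_ct5_target(["quick brown", "over"]))
# <extra_id_0> quick brown </c> <extra_id_1> over </c> <extra_id_2>
```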

The generation may look like this for T5 and cT5:
```
T5: <extra_id_0>
T5: <extra_id_0> quick
T5: <extra_id_0> quick brown
T5: <extra_id_0> quick brown <extra_id_1>
T5: <extra_id_0> quick brown <extra_id_1> over
T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2>
T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2> </s>

cT5: <extra_id_0> <pad> <extra_id_1> <pad> <extra_id_2> </s>
cT5: <extra_id_0> quick <pad> <extra_id_1> over <pad> <extra_id_2> </s>
cT5: <extra_id_0> quick brown <pad> <extra_id_1> over </c> <extra_id_2> </s>
cT5: <extra_id_0> quick brown </c> <extra_id_1> over </c> <extra_id_2> </s>
```

In the original T5, the decoder is called \\(n_s + 1 + \sum_i |s_i|\\) times autoregressively,
where \\(n_s\\) is the number of sentinel tokens and \\(s_1, ..., s_{n_s}\\) are the predicted chunks.
In contrast, cT5's decoder is called just \\(\max_i |s_i| + 1\\) times.
Generation stops when all chunks are complete, i.e., when a `</c>` token has been generated for every sentinel.
Alternatively, you can set `max_chunk_size` to force the model to stop after a chunk reaches `max_chunk_size` tokens.
The overhead of calling the decoder with a longer input is less pronounced since this computation can be parallelized on GPUs/TPUs.

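For the example above, the decoder output contains \\(n_s = 3\\) sentinel tokens and chunks of lengths 2 (`quick brown`), 1 (`over`) and 0, so T5 needs \\(3 + 1 + (2 + 1 + 0) = 7\\) autoregressive decoder calls, whereas cT5 needs only \\(\max(2, 1, 0) + 1 = 3\\).
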
## Training details

cT5 models used T5's weights as a starting point and were then finetuned on the
English [wikipedia](https://huggingface.co/datasets/wikipedia) for 3 epochs,
achieving ~74% validation accuracy (ct5-small).
The training script is written in JAX + Flax and can be found in `pretrain_ct5.py`.

Flax checkpoints can be converted to PyTorch via `convert_flax_to_pytorch.py [flax_dirname]`.

## Checkpoints

- ct5-small: https://huggingface.co/mtreviso/ct5-small-en-wiki
- ct5-base: todo
- ct5-large: todo

## Usage

```python
from transformers import AutoTokenizer
from modeling_ct5 import CT5ForConditionalGeneration  # custom class, not part of the transformers library

tokenizer = AutoTokenizer.from_pretrained("mtreviso/ct5-small-en-wiki")
model = CT5ForConditionalGeneration.from_pretrained("mtreviso/ct5-small-en-wiki")
```

For training:

```python
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> man </c> <extra_id_1> the </c> <extra_id_2>", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits
```
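
As a minimal follow-up sketch, the loss above can be used in a standard PyTorch training step (the optimizer and learning rate here are placeholders, not the settings used for the released checkpoints):

```python
import torch

# Placeholder optimizer and hyperparameters; illustrative only.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model.train()
optimizer.zero_grad()
outputs = model(input_ids=input_ids, labels=labels)
outputs.loss.backward()  # backpropagate the chunked-T5 loss
optimizer.step()
```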

For generation:

```python
texts = [
    "The <extra_id_0> walks in <extra_id_1> park",
    "UN Chief says there is no way to <extra_id_0> in Syria",
]
input_ids = tokenizer(texts, return_tensors="pt", padding=True).input_ids
generated_ids = model.generate(
    input_ids,
    use_cache=False,  # important: caching is not available for parallel decoding
    eoc_token_id=tokenizer.vocab['</c>'],  # important: must be the id of the end-of-chunk token
    max_chunk_size=5,  # the default is 9999999, i.e., effectively no limit
)
```
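
To inspect the raw predictions, the generated ids can be mapped back to token strings with the tokenizer (this step is not part of the original snippet; `convert_ids_to_tokens` is a standard HuggingFace tokenizer method):

```python
for ids in generated_ids:
    # Print the predicted sentinels, chunk tokens, </c> markers, and padding.
    print(tokenizer.convert_ids_to_tokens(ids))
```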

This will produce the following tokens:
```python
>> ['<pad>', '<extra_id_0>', '▁Walking', '▁Trail', '</c>', '<extra_id_1>', '▁the', '</c>', '<extra_id_2>', '</s>']
>> ['<pad>', '<extra_id_0>', '▁treat', '▁Syria', '</c>', '<extra_id_1>', '</s>', '<pad>', '<pad>', '<pad>']
```

You have to pass `use_cache=False` to `generate()` since caching is not available for parallel decoding.
Currently, parallel decoding is only supported for PyTorch (greedy search, greedy sampling, beam search, beam sampling) and JAX (greedy search and greedy sampling).

**Note on the beam search implementation**: my beam search implementation is slower than optimal.
This is because I reuse the structures provided by HuggingFace's implementation, namely `BeamScores` and `BeamHypotheses`, to store the beam search results for each chunk in the input.
In other words, my implementation computes independent "beams" for each chunk rather than for each input sequence.
It is possible to make it faster by using custom `BeamScores` and `BeamHypotheses` classes, but I haven't done that yet.

## Evaluation

See the notebook `evaluate_ct5.ipynb` for an example of how to evaluate cT5 in terms of accuracy and perplexity.
The notebook `profile.ipynb` shows how to profile the model to get runtimes.

Here is a comparison between cT5-small and T5-small on a subset of the WikiText-103 dataset using deterministic greedy search:

| Model     | Exact match ↑ | Edit distance ratio ↑ | Perplexity ↓ | Time (seconds) ↓ |
|-----------|---------------|-----------------------|--------------|------------------|
| T5-small  | 0.11          | 0.60                  | 2.22         | 44.71            |
| cT5-small | 0.09          | 0.58                  | 1.48         | 10.63            |

On this toy dataset, cT5-small achieves a lower perplexity while being roughly 4× faster than T5-small. However, more experiments are needed for a rigorous evaluation.
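
The exact metric definitions are in `evaluate_ct5.ipynb`; as a rough illustration, an edit-distance-style ratio between a prediction and a reference can be computed with Python's standard `difflib` (this sketch is not necessarily the implementation behind the table above):

```python
import difflib

def edit_distance_ratio(prediction: str, reference: str) -> float:
    """Similarity ratio in [0, 1]; higher means the strings are closer."""
    return difflib.SequenceMatcher(None, prediction, reference).ratio()

print(edit_distance_ratio("quick brown </c> over </c>", "quick brown </c> over </c>"))  # 1.0
```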

If you are interested in applying cT5 to real data, please contact me.