Multiple choice
A multiple choice task is similar to question answering, except several candidate answers are provided along with a context. Given the context, the model is trained to select the correct answer from among the candidates.
This guide will show you how to fine-tune BERT on the regular configuration of the SWAG dataset to select the best answer given multiple options and some context.
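Conceptually, the model gives one score to each (context, candidate) pair, and its prediction is the candidate with the highest score. The minimal sketch below shows only that selection step in plain Python. The scores are made-up numbers, not real BERT outputs, and softmax and pick_answer are hypothetical helper names used for illustration.

```python
import math


def softmax(scores):
    """Turn raw per-choice scores (logits) into probabilities that sum to 1."""
    largest = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def pick_answer(scores):
    """Return the index of the highest-scoring candidate."""
    return max(range(len(scores)), key=lambda i: scores[i])


# Hypothetical scores for one context paired with each of four candidate answers.
choice_scores = [2.1, -0.3, 0.4, -1.2]
probabilities = softmax(choice_scores)
predicted_index = pick_answer(choice_scores)  # 0, the highest score
```

Softmax preserves the ordering of the scores, so taking the argmax of the raw scores or of the probabilities selects the same candidate.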
Load SWAG dataset
Load the SWAG dataset from the 🤗 Datasets library:
>>> from datasets import load_dataset
>>> swag = load_dataset("swag", "regular")
Then take a look at an example:
>>> swag["train"][0]
{'ending0': 'passes by walking down the street playing their instruments.',
'ending1': 'has heard approaching them.',
'ending2': "arrives and they're outside dancing and asleep.",
'ending3': 'turns the lead singer watches the performance.',
'fold-ind': '3416',
'gold-source': 'gold',
'label': 0,
'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',
'sent2': 'A drum line',
'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',
'video-id': 'anetv_jkn6uvmqwh4'}
The sent1 and sent2 fields show how a sentence begins, and each ending field shows how a sentence could end. Given the sentence beginning, the model must pick the correct sentence ending, as indicated by the label field.
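To make the relationship between the fields concrete, the small sketch below rebuilds the four candidate sentences from the example printed above and picks the one the label points to. The values are copied by hand from that output, so no libraries are needed; the candidates variable is only for illustration.

```python
# The first SWAG training example, copied from the output above (some fields omitted).
example = {
    "ending0": "passes by walking down the street playing their instruments.",
    "ending1": "has heard approaching them.",
    "ending2": "arrives and they're outside dancing and asleep.",
    "ending3": "turns the lead singer watches the performance.",
    "label": 0,
    "sent1": "Members of the procession walk down the street holding small horn brass instruments.",
    "sent2": "A drum line",
    "startphrase": "Members of the procession walk down the street holding small horn brass instruments. A drum line",
}

ending_names = ["ending0", "ending1", "ending2", "ending3"]

# Each candidate is sent1, then sent2 followed by one of the four endings.
candidates = [f"{example['sent1']} {example['sent2']} {example[name]}" for name in ending_names]

# The label is the index of the correct ending.
correct = candidates[example["label"]]
print(correct)
```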
Preprocess
Load the BERT tokenizer to process the start of each sentence and the four possible endings:
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
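When the tokenizer receives a pair of texts, BERT encodes them as a single sequence of the form [CLS] A [SEP] B [SEP], and token_type_ids mark which segment each token belongs to. The sketch below reproduces that layout on already-tokenized ids. The word ids are hypothetical, and pack_pair is an illustrative helper; 101 and 102 are the [CLS] and [SEP] ids in the bert-base-uncased vocabulary, but real code should let the tokenizer insert them rather than hard-coding them.

```python
def pack_pair(first_ids, second_ids, cls_id=101, sep_id=102):
    """Lay out a sentence pair the way BERT expects: [CLS] A [SEP] B [SEP]."""
    input_ids = [cls_id] + first_ids + [sep_id] + second_ids + [sep_id]
    # Segment 0 covers [CLS], sentence A, and its [SEP]; segment 1 covers sentence B and its [SEP].
    token_type_ids = [0] * (len(first_ids) + 2) + [1] * (len(second_ids) + 1)
    # No padding yet, so every position is a real token.
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}


# Hypothetical word ids for a context (A) and a candidate ending (B).
encoded = pack_pair([7, 8], [9])
```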
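The preprocessing function needs to:
- Make four copies of the sent1 field, one per candidate ending, to use as the first sentence of each pair.
- Combine sent2 with each of the four possible sentence endings to form the second sentence of each pair, so that sent1 followed by sent2 and an ending recreates a complete candidate sentence.
- Flatten these two lists so you can tokenize them, and then unflatten them afterward so each example has input_ids, attention_mask, and token_type_ids fields holding four entries, one per choice. The label field is carried over from the original dataset.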
>>> ending_names = ["ending0", "ending1", "ending2", "ending3"]
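>>> def preprocess_function(examples):
...     # Repeat each context (sent1) once per candidate ending.
...     first_sentences = [[context] * 4 for context in examples["sent1"]]
...     question_headers = examples["sent2"]
...     # Pair each sentence start (sent2) with each of the four endings.
...     second_sentences = [
...         [f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers)
...     ]
...     # Flatten the nested lists so the tokenizer sees one (first, second) pair per candidate.
...     first_sentences = sum(first_sentences, [])
...     second_sentences = sum(second_sentences, [])
...     tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)
...     # Unflatten: regroup every four consecutive encodings into one example.
...     return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}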
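The flatten and unflatten steps are the least obvious part of preprocess_function. The sketch below isolates them on toy data, with a fake "tokenizer" (the two string lengths) standing in for BERT; flatten and unflatten are hypothetical helper names. It uses itertools.chain.from_iterable, which flattens in linear time, whereas sum(list_of_lists, []) copies the accumulated list on every addition and grows quadratically with the batch size.

```python
from itertools import chain

NUM_CHOICES = 4


def flatten(nested):
    """Turn [[a1, a2, a3, a4], [b1, ...]] into [a1, a2, a3, a4, b1, ...] in linear time."""
    return list(chain.from_iterable(nested))


def unflatten(flat, group_size=NUM_CHOICES):
    """Regroup a flat list into consecutive chunks of group_size items."""
    return [flat[i : i + group_size] for i in range(0, len(flat), group_size)]


# Two hypothetical examples, each with four (first, second) sentence pairs.
first_sentences = [["context A"] * NUM_CHOICES, ["context B"] * NUM_CHOICES]
second_sentences = [[f"A ending{j}" for j in range(NUM_CHOICES)], [f"B ending{j}" for j in range(NUM_CHOICES)]]

flat_first = flatten(first_sentences)
flat_second = flatten(second_sentences)

# Stand-in for the tokenizer: one "encoding" per pair (here, just the two string lengths).
fake_encodings = [[len(a), len(b)] for a, b in zip(flat_first, flat_second)]

# Regroup so each example again owns exactly four encodings, one per choice.
per_example = unflatten(fake_encodings)
```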
Use the 🤗 Datasets map function to apply the preprocessing function over the entire dataset. You can speed up the map function by setting batched=True to process multiple elements of the dataset at once:
>>> tokenized_swag = swag.map(preprocess_function, batched=True)
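With batched=True, map hands your function a dict of lists (one list per column) covering a slice of rows, instead of one row at a time. The rough emulation below is not the 🤗 Datasets implementation: batched_map and join_starts are hypothetical helpers that only mimic the calling convention, including map's default batch size of 1000 rows and the way returned columns are added next to the existing ones.

```python
def batched_map(columns, fn, batch_size=1000):
    """Rough emulation of Dataset.map(fn, batched=True) on a dict of columns.

    Not the 🤗 Datasets implementation: it only mimics how fn is called.
    """
    num_rows = len(next(iter(columns.values())))
    new_columns = {}
    for start in range(0, num_rows, batch_size):
        # fn receives one list per column, covering rows [start, start + batch_size).
        batch = {name: values[start : start + batch_size] for name, values in columns.items()}
        for name, values in fn(batch).items():
            new_columns.setdefault(name, []).extend(values)
    # Returned columns are added next to (or overwrite) the existing ones.
    return {**columns, **new_columns}


def join_starts(batch):
    """Hypothetical batched function: join sent1 and sent2 for every row in the batch."""
    return {"start": [f"{s1} {s2}" for s1, s2 in zip(batch["sent1"], batch["sent2"])]}


toy = {"sent1": ["a", "b", "c"], "sent2": ["x", "y", "z"]}
result = batched_map(toy, join_starts, batch_size=2)
```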
🤗 Transformers doesn’t have a data collator for multiple choice, so you will need to create one. You can adapt DataCollatorWithPadding to build batches of multiple-choice examples. The collator dynamically pads the inputs to the length of the longest sequence in each batch, so they have a uniform length; the labels are single integers and need no padding. While it is possible to pad your text in the tokenizer function by setting padding=True, dynamic padding is more efficient because each batch is padded only as much as it needs.
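A quick way to see why dynamic padding is cheaper is to count the token slots each strategy produces. The sketch below uses hypothetical sequence lengths and a hypothetical padded_tokens helper; it pads either each batch to its own longest sequence or every sequence to one fixed length.

```python
def padded_tokens(lengths, batch_size, pad_to=None):
    """Count total token slots after padding.

    pad_to=None pads each batch to its own longest sequence (dynamic padding);
    an int pads every sequence to that fixed length instead.
    """
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start : start + batch_size]
        target = pad_to if pad_to is not None else max(batch)
        total += target * len(batch)
    return total


# Hypothetical sequence lengths (in tokens) for 8 sequences, batched 4 at a time.
lengths = [12, 15, 9, 14, 40, 38, 41, 39]
dynamic = padded_tokens(lengths, batch_size=4)  # 15 * 4 + 41 * 4 = 224 slots
static = padded_tokens(lengths, batch_size=4, pad_to=max(lengths))  # 41 * 8 = 328 slots
```

The saving comes from the short first batch: it is padded to 15 tokens instead of 41.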
DataCollatorForMultipleChoice flattens all the model inputs, applies padding, and then unflattens the results. If you are using PyTorch, define it as follows:
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import torch
>>> @dataclass
... class DataCollatorForMultipleChoice:
...     """
...     Data collator that will dynamically pad the inputs for multiple choice received.
...     """
...     tokenizer: PreTrainedTokenizerBase
...     padding: Union[bool, str, PaddingStrategy] = True
...     max_length: Optional[int] = None
...     pad_to_multiple_of: Optional[int] = None
...     def __call__(self, features):
...         label_name = "label" if "label" in features[0].keys() else "labels"
...         labels = [feature.pop(label_name) for feature in features]
...         batch_size = len(features)
...         num_choices = len(features[0]["input_ids"])
...         flattened_features = [
...             [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
...         ]
...         flattened_features = sum(flattened_features, [])
...         batch = self.tokenizer.pad(
...             flattened_features,
...             padding=self.padding,
...             max_length=self.max_length,
...             pad_to_multiple_of=self.pad_to_multiple_of,
...             return_tensors="pt",
...         )
...         batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
...         batch["labels"] = torch.tensor(labels, dtype=torch.int64)
...         return batch
If you are using TensorFlow, use this version instead. It pads with the tokenizer in the same way but returns TensorFlow tensors:
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import tensorflow as tf
>>> @dataclass
... class DataCollatorForMultipleChoice:
...     """
...     Data collator that will dynamically pad the inputs for multiple choice received.
...     """
...     tokenizer: PreTrainedTokenizerBase
...     padding: Union[bool, str, PaddingStrategy] = True
...     max_length: Optional[int] = None
...     pad_to_multiple_of: Optional[int] = None
...     def __call__(self, features):
...         label_name = "label" if "label" in features[0].keys() else "labels"
...         labels = [feature.pop(label_name) for feature in features]
...         batch_size = len(features)
...         num_choices = len(features[0]["input_ids"])
...         flattened_features = [
...             [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
...         ]
...         flattened_features = sum(flattened_features, [])
...         batch = self.tokenizer.pad(
...             flattened_features,
...             padding=self.padding,
...             max_length=self.max_length,
...             pad_to_multiple_of=self.pad_to_multiple_of,
...             return_tensors="tf",
...         )
...         batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
...         batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
...         return batch
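Stripped of the tokenizer and the tensor library, the collator's flatten, pad, and unflatten steps look roughly like the plain-Python sketch below. collate_multiple_choice is a hypothetical helper: it handles only input_ids, pads with id 0 (the [PAD] id in bert-base-uncased; real code should use tokenizer.pad_token_id), builds the attention mask by hand, and, unlike the collators above, does not mutate its input features.

```python
def collate_multiple_choice(features, pad_id=0):
    """Flatten, pad, and unflatten multiple-choice features held in plain lists.

    Each feature is {"input_ids": [[...], ...], "label": int}, with one
    inner list of token ids per choice. Only input_ids are handled here.
    """
    labels = [feature["label"] for feature in features]
    num_choices = len(features[0]["input_ids"])

    # Flatten: one row per (example, choice) pair.
    flat = [choice for feature in features for choice in feature["input_ids"]]

    # Pad every row to the longest row in this batch (dynamic padding).
    longest = max(len(row) for row in flat)
    input_ids = [row + [pad_id] * (longest - len(row)) for row in flat]
    attention_mask = [[1] * len(row) + [0] * (longest - len(row)) for row in flat]

    # Unflatten: regroup consecutive rows into shape (batch_size, num_choices, longest).
    def regroup(rows):
        return [rows[i : i + num_choices] for i in range(0, len(rows), num_choices)]

    return {"input_ids": regroup(input_ids), "attention_mask": regroup(attention_mask), "labels": labels}


# Two hypothetical examples with two choices each (SWAG has four).
features = [
    {"input_ids": [[101, 7, 102], [101, 8, 9, 102]], "label": 1},
    {"input_ids": [[101, 102], [101, 5, 6, 7, 102]], "label": 0},
]
batch = collate_multiple_choice(features)
```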
Train
If you are using PyTorch, load BERT with AutoModelForMultipleChoice:
>>> from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer
>>> model = AutoModelForMultipleChoice.from_pretrained("bert-base-uncased")
If you aren’t familiar with fine-tuning a model with the Trainer, take a look at the basic fine-tuning tutorial first.
At this point, only three steps remain:
- Define your training hyperparameters in TrainingArguments.
- Pass the training arguments to Trainer along with the model, dataset, tokenizer, and data collator.
- Call train() to fine-tune your model.
>>> training_args = TrainingArguments(
... output_dir="./results",
... evaluation_strategy="epoch",
... learning_rate=5e-5,
... per_device_train_batch_size=16,
... per_device_eval_batch_size=16,
... num_train_epochs=3,
... weight_decay=0.01,
... )
>>> trainer = Trainer(
... model=model,
... args=training_args,
... train_dataset=tokenized_swag["train"],
... eval_dataset=tokenized_swag["validation"],
... tokenizer=tokenizer,
... data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
... )
>>> trainer.train()
To fine-tune a model in TensorFlow, start by converting your datasets to the tf.data.Dataset format with to_tf_dataset. Specify the inputs in columns, the targets in label_cols, whether to shuffle the dataset order, the batch size, and the data collator:
>>> batch_size = 16
>>> data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
>>> tf_train_set = tokenized_swag["train"].to_tf_dataset(
... columns=["attention_mask", "input_ids"],
... label_cols=["labels"],
... shuffle=True,
... batch_size=batch_size,
... collate_fn=data_collator,
... )
>>> tf_validation_set = tokenized_swag["validation"].to_tf_dataset(
... columns=["attention_mask", "input_ids"],
... label_cols=["labels"],
... shuffle=False,
... batch_size=batch_size,
... collate_fn=data_collator,
... )
If you aren’t familiar with fine-tuning a model with Keras, take a look at the basic fine-tuning tutorial first.
Set up an optimizer function, learning rate schedule, and some training hyperparameters:
>>> from transformers import create_optimizer
>>> batch_size = 16
>>> num_train_epochs = 2
>>> total_train_steps = (len(tokenized_swag["train"]) // batch_size) * num_train_epochs
>>> optimizer, schedule = create_optimizer(init_lr=5e-5, num_warmup_steps=0, num_train_steps=total_train_steps)
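The step count above divides the training set into whole batches and multiplies by the number of epochs. The sketch below repeats that arithmetic and models the resulting learning rate, assuming that create_optimizer with num_warmup_steps=0 and its default settings decays the rate linearly from init_lr to 0 over num_train_steps. total_steps and linear_decay_lr are hypothetical helpers and a simplified model, not part of 🤗 Transformers.

```python
def total_steps(num_examples, batch_size, num_epochs):
    """Same arithmetic as above: whole batches per epoch times the number of epochs."""
    return (num_examples // batch_size) * num_epochs


def linear_decay_lr(step, init_lr, num_train_steps, num_warmup_steps=0):
    """Learning rate at `step` under linear warmup followed by linear decay to 0.

    Assumption: a linear (power 1.0) schedule ending at 0; a simplified model, not the library code.
    """
    if step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    decay_steps = max(1, num_train_steps - num_warmup_steps)
    progress = min(1.0, (step - num_warmup_steps) / decay_steps)
    return init_lr * (1.0 - progress)


# Hypothetical dataset of 1,000 examples, batch size 16, 2 epochs.
steps = total_steps(1000, batch_size=16, num_epochs=2)  # 62 batches per epoch -> 124 steps
```

Note that the floor division drops any final partial batch, so the step count is slightly lower than the number of batches the dataset actually yields when its size is not a multiple of the batch size.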
Load BERT with TFAutoModelForMultipleChoice:
>>> from transformers import TFAutoModelForMultipleChoice
>>> model = TFAutoModelForMultipleChoice.from_pretrained("bert-base-uncased")
Configure the model for training with compile:
>>> model.compile(
... optimizer=optimizer,
... loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
... )
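SparseCategoricalCrossentropy(from_logits=True) takes integer labels (here, the index of the correct ending) and raw logits, one per choice. The sketch below computes the same quantity by hand, averaged over the batch as Keras does by default: for each example, the log-sum-exp of its logits minus the logit of the correct choice. The logits are hypothetical, and sparse_ce_from_logits is an illustrative name.

```python
import math


def sparse_ce_from_logits(logits, labels):
    """Mean cross-entropy for integer class labels and unnormalized logits.

    For each row: loss = logsumexp(row) - row[label].
    """
    losses = []
    for row, label in zip(logits, labels):
        largest = max(row)  # shift by the max so exp() cannot overflow
        log_sum_exp = largest + math.log(sum(math.exp(x - largest) for x in row))
        losses.append(log_sum_exp - row[label])
    return sum(losses) / len(losses)


# Hypothetical logits for a batch of two examples with four choices each.
logits = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
labels = [0, 3]
loss = sparse_ce_from_logits(logits, labels)
```

With uniform logits, as in the second example, the loss is log(4), the value for a model that cannot tell the four endings apart.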
Call fit to fine-tune the model:
>>> model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=num_train_epochs)