Number of classes
I would like to run this model with more than 10 classes. I read in another discussion that this is possible if I run the model locally. How can I do this?
I have the same problem
I am having the same problem
This can be done as follows. For demonstration purposes, let's say we want to zero-shot classify all the texts in the stanfordnlp/imdb dataset.
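Models like "facebook/bart-large-mnli" are natural language inference (NLI) models. Each text is paired with every candidate label, so we get as many (text, label) pairs per text as we have candidate labels. For each pair, the model predicts one of three classes: whether the candidate label is entailed by, neutral to, or contradicted by the text. Because every pair is scored independently, nothing in the model itself limits the number of candidate labels when you run it locally.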
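The pair expansion described above can be sketched in plain Python, without loading any model. This is only an illustration: `build_pairs` and the `topic_*` label names are hypothetical, and the template "This example is {}." is the default hypothesis template of the transformers zero-shot-classification pipeline, whereas the code in this thread passes the bare label as the second sequence.

```python
def build_pairs(texts, candidate_labels, template="This example is {}."):
    """Return parallel lists of premises and hypotheses, one entry per (text, label) pair."""
    premises, hypotheses = [], []
    for text in texts:
        for label in candidate_labels:
            premises.append(text)
            # Turn the label into a full sentence the NLI model can judge against the text.
            hypotheses.append(template.format(label))
    return premises, hypotheses


# Hypothetical label names, deliberately more than 10.
labels = [f"topic_{i}" for i in range(12)]
texts = ["A great film.", "The match ended 2-1."]

premises, hypotheses = build_pairs(texts, labels)
print(len(premises))  # 2 texts x 12 labels
```

The two lists can be passed to the tokenizer as the first and second sequence. The cost grows linearly with the number of labels, since every text is run through the model once per label.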
from datasets import load_dataset
from transformers import AutoTokenizer

model_name = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("stanfordnlp/imdb")

# let's say we have 3 candidate labels
candidate_labels = ["sports", "science", "politics"]

def expand_dataset(example):
    all_texts = []
    all_candidate_labels = []
    for text in example["text"]:
        for label in candidate_labels:
            all_texts.append(text)
            all_candidate_labels.append(label)
    inputs = tokenizer(all_texts, all_candidate_labels, truncation=True, padding="max_length", return_tensors="pt")
    return inputs

# expand the dataset by converting each text into a set of (text, candidate label) pairs for the model
dataset = dataset["train"].map(expand_dataset, batched=True, remove_columns=dataset["train"].column_names)
Once we have prepared the data for the model, we can run a forward pass in batches (set the batch size as high as possible on your given hardware):
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

@torch.no_grad()
def predict_labels(batch):
    # the mapped dataset stores plain lists, so convert them to tensors on the device
    inputs = {k: torch.tensor(batch[k], device=device) for k in ("input_ids", "attention_mask")}
    # forward pass
    outputs = model(**inputs)
    # get the logits
    logits = outputs.logits
    # add the predicted label (either "entailment", "neutral" or "contradiction") to the dataset
    predicted_ids = logits.argmax(-1).tolist()
    predicted_labels = [model.config.id2label[i] for i in predicted_ids]
    # return only the new column; map() adds it to the existing ones
    return {"labels": predicted_labels}

dataset = dataset.map(predict_labels, batched=True, batch_size=4)
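The code above yields one NLI verdict per (text, label) pair. To get a single class per text, one option is to compare the entailment logits across all candidate labels of that text, which is roughly what the zero-shot pipeline does in single-label mode. The following is a minimal sketch with made-up logits. `rank_labels` is a hypothetical helper, and `entailment_id=2` is an assumption that should be checked against `model.config.label2id`. It also assumes the pairs are still in the order in which they were expanded, i.e. the dataset was not shuffled.

```python
import math


def rank_labels(pair_logits, candidate_labels, entailment_id=2):
    """Group per-pair logits by text and softmax the entailment logit over the labels."""
    n = len(candidate_labels)
    results = []
    for start in range(0, len(pair_logits), n):
        # Entailment logit of every candidate label for one text.
        ent = [row[entailment_id] for row in pair_logits[start:start + n]]
        # Numerically stable softmax over the candidate labels.
        m = max(ent)
        exps = [math.exp(x - m) for x in ent]
        total = sum(exps)
        scores = [e / total for e in exps]
        best = max(range(n), key=scores.__getitem__)
        results.append((candidate_labels[best], scores))
    return results


candidate_labels = ["sports", "science", "politics"]
# Made-up logits in the assumed order [contradiction, neutral, entailment]:
# two texts, three (text, label) pairs each.
fake_logits = [
    [2.0, 0.1, -1.0], [1.5, 0.2, -0.5], [-1.0, 0.3, 3.0],   # text 1
    [-2.0, 0.0, 4.0], [1.0, 0.5, 0.0], [2.5, 0.1, -2.0],    # text 2
]

ranked = rank_labels(fake_logits, candidate_labels)
for label, scores in ranked:
    print(label, [round(s, 3) for s in scores])
```

The same grouping works for any number of candidate labels, as long as each text contributes exactly `len(candidate_labels)` consecutive rows.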