NeuronTrainer
The NeuronTrainer class provides an extended API for the feature-complete Transformers Trainer. It is used in all the example scripts.
The NeuronTrainer class is optimized for 🤗 Transformers models running on AWS Trainium.
Here is an example of how to customize NeuronTrainer to use a weighted loss (useful when you have an unbalanced training set):
```python
import torch
from torch import nn
from optimum.neuron import NeuronTrainer


class CustomNeuronTrainer(NeuronTrainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
    # Transformers versions pass to compute_loss
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights);
        # the weight tensor must live on the same device as the logits
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=logits.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
```
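The class weights in the weighted-loss example are hard-coded. One common heuristic for deriving them from an unbalanced training set is inverse-frequency weighting, where class c gets weight total / (num_labels × count_c). The sketch below is an assumption for illustration, not something NeuronTrainer provides; the helper name `inverse_frequency_weights` is hypothetical, and it is kept in plain Python so it can run anywhere.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_labels):
    """Hypothetical helper: weight class c by total / (num_labels * count_c).

    Rarer classes receive larger weights. Classes that never appear in
    `labels` get weight 0.0 to avoid division by zero.
    """
    counts = Counter(labels)
    total = len(labels)
    weights = []
    for c in range(num_labels):
        n = counts.get(c, 0)
        weights.append(total / (num_labels * n) if n else 0.0)
    return weights


if __name__ == "__main__":
    # Toy, unbalanced label list: class 0 is four times as frequent as class 2
    train_labels = [0, 0, 0, 0, 1, 1, 2]
    print(inverse_frequency_weights(train_labels, num_labels=3))
```

The resulting list could then replace the literal weights in the loss, for example `torch.tensor(weights, device=logits.device)`, so the weighting tracks the actual label distribution.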
Another way to customize the training loop behavior of the PyTorch NeuronTrainer is to use callbacks, which can inspect the training loop state (for progress reporting, logging to TensorBoard or other ML platforms, etc.) and take decisions (such as early stopping).
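As a hedged sketch of the callback approach, the class below subclasses Transformers' `TrainerCallback` and sets `control.should_training_stop` once the logged training loss has failed to improve for a few logging steps. The class name `LossPlateauCallback` and its `patience`/`min_delta` parameters are hypothetical, and the sketch assumes NeuronTrainer accepts the same `callbacks` argument as the Transformers Trainer. For metric-based early stopping, Transformers already ships `EarlyStoppingCallback`.

```python
from transformers import TrainerCallback


class LossPlateauCallback(TrainerCallback):
    """Hypothetical callback: request a stop when the logged training loss has
    not improved by more than `min_delta` for `patience` consecutive logs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_logs = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Only react to log events that actually contain a training loss
        if logs is None or "loss" not in logs:
            return control
        loss = logs["loss"]
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.bad_logs = 0
        else:
            self.bad_logs += 1
        if self.bad_logs >= self.patience:
            # Ask the trainer to end training after the current step
            control.should_training_stop = True
        return control
```

Assuming the argument is inherited from Trainer, it would be registered as `NeuronTrainer(..., callbacks=[LossPlateauCallback(patience=5)])`. Stopping on training loss is noisier than stopping on an evaluation metric, which is why the built-in `EarlyStoppingCallback` works on evaluation results instead.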
NeuronTrainer
Trainer that is suited for performing training on AWS Trainium instances.
Seq2SeqNeuronTrainer
Seq2SeqTrainer that is suited for performing training on AWS Trainium instances.