Iterative Trainer
Iterative fine-tuning is a training method that lets you perform custom actions (generation and filtering, for example) between optimization steps. TRL provides an easy-to-use API to fine-tune your models iteratively in just a few lines of code.
Usage
To get started quickly, instantiate a model and a tokenizer, then pass them to the trainer.
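from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import IterativeSFTTrainer

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # Fall back to the EOS token when the tokenizer defines no padding token
    tokenizer.pad_token = tokenizer.eos_token

trainer = IterativeSFTTrainer(
    model,
    processing_class=tokenizer,
)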
You can pass either a list of strings or a list of tensors to the step function.
Using a list of tensors as input:
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}
trainer.step(**inputs)
Using a list of strings as input:
inputs = {
    "texts": texts,
}
trainer.step(**inputs)
For causal language models, labels are created automatically from input_ids or from texts. For sequence-to-sequence models, you have to provide your own labels or texts_labels.
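The overall pattern described above (custom actions such as generation and filtering between optimization steps) can be sketched as a plain loop. This is a minimal illustration, not TRL's implementation: `iterative_finetune`, `generate`, `keep`, and `step` are hypothetical names introduced here, and the stand-in functions exist only so the sketch runs without a model.

```python
from typing import Callable, List


def iterative_finetune(
    prompts: List[str],
    generate: Callable[[List[str]], List[str]],
    keep: Callable[[str], bool],
    step: Callable[[List[str]], dict],
    num_rounds: int = 3,
) -> List[dict]:
    """Alternate generation, filtering, and one optimization step per round."""
    history = []
    for _ in range(num_rounds):
        # Custom action 1: produce candidate training texts
        candidates = generate(prompts)
        # Custom action 2: keep only the candidates that pass the filter
        selected = [text for text in candidates if keep(text)]
        if not selected:
            # Nothing survived the filter, so skip the optimization step
            continue
        # One optimization step on the surviving texts
        history.append(step(selected))
    return history


# Hypothetical stand-ins so the sketch runs without a model.
def fake_generate(prompts: List[str]) -> List[str]:
    return [p + " ok" for p in prompts] + [p + " bad" for p in prompts]


seen: List[List[str]] = []


def fake_step(texts: List[str]) -> dict:
    seen.append(list(texts))
    return {"num_texts": len(texts)}


stats = iterative_finetune(
    ["a", "b"], fake_generate, lambda t: t.endswith("ok"), fake_step, num_rounds=2
)
print(stats)
```

With a real trainer, the `step` callable would presumably wrap the call shown earlier, for example `lambda texts: trainer.step(texts=texts)`, and `generate` would wrap the model's own generation routine.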
IterativeTrainer
class trl.IterativeSFTTrainer
( model: Optional = None args: Optional = None processing_class: Union = None optimizers: Tuple = (None, None) data_collator: Optional = None eval_dataset: Union = None max_length: Optional = None truncation_mode: Optional = 'keep_end' preprocess_logits_for_metrics: Optional = None compute_metrics: Optional = None optimize_device_cache: Optional = False )
Parameters
- model (PreTrainedModel) — Model to be optimized, either an ‘AutoModelForCausalLM’ or an ‘AutoModelForSeq2SeqLM’. Check the documentation of PreTrainedModel for more details.
- args (transformers.TrainingArguments) — The arguments to use for training.
- processing_class (PreTrainedTokenizerBase or BaseImageProcessor or FeatureExtractionMixin or ProcessorMixin, optional) — Processing class used to process the data. If provided, will be used to automatically process the inputs for the model, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model.
- optimizers (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — The optimizer and scheduler to use for training.
- data_collator (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], optional) — Data collator to be used for training and passed along the dataloader.
- eval_dataset (datasets.Dataset) — The dataset to use for evaluation.
- max_length (int, defaults to None) — The maximum length of the input.
- truncation_mode (str, defaults to keep_end) — The truncation mode to use, either keep_end or keep_start.
- preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — The function to use to preprocess the logits before computing the metrics.
- compute_metrics (Callable[[EvalPrediction], Dict], optional) — The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary string to metric values.
- optimize_device_cache (bool, optional, defaults to False) — Optimize CUDA cache for slightly more memory-efficient training.
The IterativeSFTTrainer can be used to fine-tune models with methods that require some steps between optimization steps.
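To make the max_length and truncation_mode parameters concrete, here is a minimal sketch of what the two modes plausibly mean for a single token sequence. It assumes that keep_end retains the last max_length tokens and keep_start the first max_length tokens, as the names suggest; `truncate` is a hypothetical helper written for this illustration, not a TRL function, and the trainer's internal handling may differ.

```python
from typing import List, Optional


def truncate(
    token_ids: List[int],
    max_length: Optional[int],
    truncation_mode: str = "keep_end",
) -> List[int]:
    """Illustrative truncation of one token sequence (assumed semantics)."""
    if max_length is None or len(token_ids) <= max_length:
        # No limit set, or the sequence already fits: return it unchanged
        return list(token_ids)
    if truncation_mode == "keep_end":
        # Drop tokens from the start, keep the trailing max_length tokens
        return list(token_ids[len(token_ids) - max_length:])
    if truncation_mode == "keep_start":
        # Drop tokens from the end, keep the leading max_length tokens
        return list(token_ids[:max_length])
    raise ValueError(f"Unknown truncation_mode: {truncation_mode!r}")


ids = [1, 2, 3, 4, 5]
print(truncate(ids, 3, "keep_end"))
print(truncate(ids, 3, "keep_start"))
```

Under these assumptions, keep_end favors the most recent tokens of a sequence, while keep_start favors its beginning.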
create_model_card
( model_name: Optional = None dataset_name: Optional = None tags: Union = None )
Creates a draft of a model card using the information available to the Trainer.
step
( input_ids: Optional = None attention_mask: Optional = None labels: Optional = None texts: Optional = None texts_labels: Optional = None ) → dict[str, Any]
Parameters
- input_ids (List[torch.LongTensor], optional) — List of tensors containing the input_ids (if not provided, texts will be used)
- attention_mask (List[torch.LongTensor], optional) — List of tensors containing the attention_mask
- labels (List[torch.FloatTensor], optional) — List of tensors containing the labels (if set to None, will default to input_ids)
- texts (List[str], optional) — List of strings containing the text input (if not provided, input_ids will directly be used)
- texts_labels (List[str], optional) — List of strings containing the text labels (if set to None, will default to texts)
Returns
dict[str, Any]
A summary of the training statistics
Run an optimization step given a list of input_ids, attention_mask, and labels, or a list of texts and texts_labels.
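The defaulting rules in the parameter descriptions above can be summarized in a short sketch. `resolve_step_inputs` is a hypothetical helper that mirrors only the documented behavior for causal language models (labels fall back to input_ids, texts_labels fall back to texts); it is not the trainer's own code. Because the documentation here does not say what happens when both input_ids and texts are passed, the sketch simply rejects that case.

```python
from typing import Any, List, Optional, Tuple


def resolve_step_inputs(
    input_ids: Optional[List[Any]] = None,
    labels: Optional[List[Any]] = None,
    texts: Optional[List[str]] = None,
    texts_labels: Optional[List[str]] = None,
) -> Tuple[List[Any], List[Any]]:
    """Return (inputs, labels) following the documented defaults."""
    if input_ids is None and texts is None:
        raise ValueError("Provide either input_ids or texts.")
    if input_ids is not None and texts is not None:
        # Assumption of this sketch: treat the ambiguous case as an error
        raise ValueError("Provide only one of input_ids or texts.")
    if input_ids is not None:
        # Tensor path: labels default to the inputs themselves
        return input_ids, labels if labels is not None else input_ids
    # Text path: text labels default to the input texts
    return texts, texts_labels if texts_labels is not None else texts


print(resolve_step_inputs(texts=["hello world"]))
print(resolve_step_inputs(texts=["question"], texts_labels=["answer"]))
```

For sequence-to-sequence models the fallback branches would not apply, since labels or texts_labels have to be supplied explicitly.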