Transformers documentation

LLaVa

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v4.46.3).

Overview

LLaVa is an open-source chatbot trained by fine-tuning LLaMA/Vicuna on GPT-generated multimodal instruction-following data. It is an auto-regressive language model based on the transformer architecture. In other words, it is a multimodal version of LLMs fine-tuned for chat / instructions.

The LLaVa model was proposed in Visual Instruction Tuning and improved in Improved Baselines with Visual Instruction Tuning by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee.

The abstract from the paper is the following:

Large multimodal models (LMM) have recently shown encouraging progress with visual instruction tuning. In this note, we show that the fully-connected vision-language cross-modal connector in LLaVA is surprisingly powerful and data-efficient. With simple modifications to LLaVA, namely, using CLIP-ViT-L-336px with an MLP projection and adding academic-task-oriented VQA data with simple response formatting prompts, we establish stronger baselines that achieve state-of-the-art across 11 benchmarks. Our final 13B checkpoint uses merely 1.2M publicly available data, and finishes full training in ∼1 day on a single 8-A100 node. We hope this can make state-of-the-art LMM research more accessible. Code and model will be publicly available.

LLaVa architecture. Taken from the original paper.

This model was contributed by ArthurZ and ybelkada. The original code can be found here.

Usage tips

  • We advise users to use padding_side="left" when computing batched generation as it leads to more accurate results. Simply make sure to call processor.tokenizer.padding_side = "left" before generating.

  • Note that the model has not been explicitly trained to process multiple images in the same prompt. Although this is technically possible, you may experience inaccurate results.
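To illustrate why left padding is advised for batched generation, here is a minimal sketch in plain Python. It does not use the real tokenizer: `pad_batch`, `PAD_ID` and the token ids are hypothetical names and values chosen only for illustration. With left padding every row ends on a real token, so generation continues from prompt content rather than from padding.

```python
# Illustrative only: plain-Python padding, not the transformers tokenizer.
PAD_ID = 0  # hypothetical pad token id


def pad_batch(sequences, padding_side="left", pad_id=PAD_ID):
    """Pad lists of token ids to a common length and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        pad = [pad_id] * (max_len - len(seq))
        pad_mask = [0] * len(pad)   # 0 = padding position
        real_mask = [1] * len(seq)  # 1 = real token
        if padding_side == "left":
            input_ids.append(pad + list(seq))
            attention_mask.append(pad_mask + real_mask)
        elif padding_side == "right":
            input_ids.append(list(seq) + pad)
            attention_mask.append(real_mask + pad_mask)
        else:
            raise ValueError("padding_side must be 'left' or 'right'")
    return input_ids, attention_mask


batch = [[5, 6, 7, 8], [9, 10]]
left_ids, left_mask = pad_batch(batch, "left")
right_ids, right_mask = pad_batch(batch, "right")

print(left_ids)   # [[5, 6, 7, 8], [0, 0, 9, 10]]
print(right_ids)  # [[5, 6, 7, 8], [9, 10, 0, 0]]
```

In the right-padded batch the second row ends with padding, which is the situation the tip above is meant to avoid.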

Note: LLaVA models after release v4.46 will raise warnings about adding processor.patch_size = {{patch_size}}, processor.num_additional_image_tokens = {{num_additional_image_tokens}} and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or open a PR if it is not owned by you. Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many <image> placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise merging the embeddings will fail. The attributes can be obtained from the model config, as model.config.vision_config.patch_size or model.config.vision_feature_select_strategy. The num_additional_image_tokens should be 1 if the vision backbone adds a CLS token or 0 if nothing extra is added to the vision patches.
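The per-image token count mentioned in the note can be estimated with a little arithmetic. The sketch below is an approximation, not the library's code: `estimate_image_tokens` is a hypothetical helper, and it assumes a square image, that the "default" strategy drops one token (the CLS token) and that "full" keeps every token.

```python
def estimate_image_tokens(image_size, patch_size,
                          num_additional_image_tokens=1,
                          vision_feature_select_strategy="default"):
    """Rough estimate of how many <image> placeholders one image expands to.

    Assumption: "default" removes the CLS token, "full" keeps all tokens.
    """
    patches_per_side = image_size // patch_size
    tokens = patches_per_side * patches_per_side + num_additional_image_tokens
    if vision_feature_select_strategy == "default":
        tokens -= 1  # assumed: the CLS token is dropped
    return tokens


# CLIP-ViT-L-336px with 14x14 patches, as used by LLaVa-1.5
default_tokens = estimate_image_tokens(336, 14, 1, "default")
full_tokens = estimate_image_tokens(336, 14, 1, "full")
print(default_tokens)  # 576
print(full_tokens)     # 577
```

Under these assumptions the "default" result matches the image_seq_length default of 576 listed for LlavaConfig.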

Single image inference

For best results, we recommend using the processor’s apply_chat_template() method to format your prompt correctly. To do so, you need to construct a conversation history; passing in a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys “role” and “content”. The “content” should be a list of dictionaries, for “text” and “image” modalities, as follows:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What’s shown in this image?"},
        ],
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "text": "This image shows a red stop sign."}],
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the image in more details."},
        ],
    },
]

text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Note that the template only formats your prompt; you still have to tokenize it and obtain pixel values for your images
print(text_prompt)
# "USER: <image>\nWhat’s shown in this image? ASSISTANT: This image shows a red stop sign.</s>USER: Describe the image in more details. ASSISTANT:"

Batched inference

LLaVa also supports batched inference. Here is how you can do it:

import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load the model in half-precision
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", torch_dtype=torch.float16, device_map="auto")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Get two different images
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image_stop = Image.open(requests.get(url, stream=True).raw)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image_cats = Image.open(requests.get(url, stream=True).raw)

# Prepare a batch of two prompts
conversation_1 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

conversation_2 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True)
prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True)
prompts = [prompt_1, prompt_2]

# Pass the images in the order they appear in the text prompts (one image per prompt here)
inputs = processor(images=[image_stop, image_cats], text=prompts, padding=True, return_tensors="pt").to(model.device, torch.float16)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)
processor.batch_decode(generate_ids, skip_special_tokens=True)

If you want to construct a chat prompt yourself, below is a list of prompt formats accepted by each LLaVa checkpoint:

llava-interleave models require the following format:

"<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant"

For a multi-turn conversation:

"<|im_start|>user <image>\n<prompt1><|im_end|><|im_start|>assistant <answer1><|im_end|><|im_start|>user <image>\n<prompt2><|im_end|><|im_start|>assistant "

llava-1.5 models require the following format:

"USER: <image>\n<prompt> ASSISTANT:"

For a multi-turn conversation:

"USER: <image>\n<prompt1> ASSISTANT: <answer1></s>USER: <prompt2> ASSISTANT: <answer2></s>USER: <prompt3> ASSISTANT:"
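As a rough sketch of how the llava-1.5 format above could be assembled by hand, the snippet below builds the string from a list of turns. `build_llava15_prompt` and the turn dictionaries are hypothetical and are not part of transformers. For real use, the processor's apply_chat_template() remains the safer route.

```python
def build_llava15_prompt(turns, add_generation_prompt=True):
    """Assemble a llava-1.5 style prompt string from a list of turns.

    Each turn is a dict with "role" ("user" or "assistant"), "text",
    and an optional "image" flag (user turns only).
    """
    parts = []
    for turn in turns:
        if turn["role"] == "user":
            image_tag = "<image>\n" if turn.get("image", False) else ""
            parts.append("USER: " + image_tag + turn["text"] + " ")
        elif turn["role"] == "assistant":
            parts.append("ASSISTANT: " + turn["text"] + "</s>")
        else:
            raise ValueError("unknown role: " + repr(turn["role"]))
    if add_generation_prompt:
        parts.append("ASSISTANT:")  # cue the model to answer next
    return "".join(parts)


single_turn = build_llava15_prompt(
    [{"role": "user", "text": "What is shown in this image?", "image": True}]
)
multi_turn = build_llava15_prompt(
    [
        {"role": "user", "text": "What is shown in this image?", "image": True},
        {"role": "assistant", "text": "A red stop sign."},
        {"role": "user", "text": "Describe it in more detail."},
    ]
)
print(repr(single_turn))
print(repr(multi_turn))
```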

Using Flash Attention 2

Flash Attention 2 is a faster, optimized attention implementation. Please refer to the Flash Attention 2 section of the performance docs.
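A minimal sketch of how loading with Flash Attention 2 might look is given below. It assumes the flash-attn package is installed and a supported GPU is available. `build_load_kwargs` and `load_llava` are hypothetical helper names, and the heavy imports are deferred so that nothing is downloaded until `load_llava` is actually called.

```python
def build_load_kwargs(use_flash_attention=True):
    """Keyword arguments for from_pretrained (illustrative helper)."""
    kwargs = {"device_map": "auto"}
    if use_flash_attention:
        # Selects the Flash Attention 2 kernels; needs flash-attn installed.
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


def load_llava(model_id="llava-hf/llava-1.5-7b-hf", use_flash_attention=True):
    """Load the model in half precision; downloads weights when called."""
    import torch
    from transformers import LlavaForConditionalGeneration

    return LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # Flash Attention 2 expects fp16 or bf16
        **build_load_kwargs(use_flash_attention),
    )


print(build_load_kwargs())
```

Calling `load_llava()` would then return a model ready for the generation examples in this page.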

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with LLaVa.

Image-to-Text

LlavaConfig

class transformers.LlavaConfig

( vision_config = None text_config = None ignore_index = -100 image_token_index = 32000 projector_hidden_act = 'gelu' vision_feature_select_strategy = 'default' vision_feature_layer = -2 image_seq_length = 576 **kwargs )

Parameters

  • vision_config (Union[AutoConfig, dict], optional, defaults to CLIPVisionConfig) — The config object or dictionary of the vision backbone.
  • text_config (Union[AutoConfig, dict], optional, defaults to LlamaConfig) — The config object or dictionary of the text backbone.
  • ignore_index (int, optional, defaults to -100) — The ignore index for the loss function.
  • image_token_index (int, optional, defaults to 32000) — The image token index to encode the image prompt.
  • projector_hidden_act (str, optional, defaults to "gelu") — The activation function used by the multimodal projector.
  • vision_feature_select_strategy (str, optional, defaults to "default") — The feature selection strategy used to select the vision feature from the vision backbone. Can be one of "default" or "full".
  • vision_feature_layer (int, optional, defaults to -2) — The index of the layer to select the vision feature.
  • image_seq_length (int, optional, defaults to 576) — Sequence length of one image embedding.

This is the configuration class to store the configuration of a LlavaForConditionalGeneration. It is used to instantiate a Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Llava-9B.

e.g. llava-hf/llava-9b

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

>>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig

>>> # Initializing a CLIP-vision config
>>> vision_config = CLIPVisionConfig()

>>> # Initializing a Llama config
>>> text_config = LlamaConfig()

>>> # Initializing a Llava llava-1.5-7b style configuration
>>> configuration = LlavaConfig(vision_config, text_config)

>>> # Initializing a model from the llava-1.5-7b style configuration
>>> model = LlavaForConditionalGeneration(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

LlavaProcessor

class transformers.LlavaProcessor

( image_processor = None tokenizer = None patch_size = None vision_feature_select_strategy = None chat_template = None image_token = '<image>' num_additional_image_tokens = 0 **kwargs )

Parameters

  • image_processor (CLIPImageProcessor, optional) — The image processor is a required input.
  • tokenizer (LlamaTokenizerFast, optional) — The tokenizer is a required input.
  • patch_size (int, optional) — Patch size from the vision tower.
  • vision_feature_select_strategy (str, optional) — The feature selection strategy used to select the vision feature from the vision backbone. Should be the same as in the model’s config.
  • chat_template (str, optional) — A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
  • image_token (str, optional, defaults to "<image>") — Special token used to denote image location.
  • num_additional_image_tokens (int, optional, defaults to 0) — Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other extra tokens appended, no need to set this arg.

Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.

LlavaProcessor offers all the functionalities of CLIPImageProcessor and LlamaTokenizerFast. See the __call__() and decode() for more information.

batch_decode

( *args **kwargs )

This method forwards all its arguments to LlamaTokenizerFast’s batch_decode(). Please refer to the docstring of this method for more information.

decode

( *args **kwargs )

This method forwards all its arguments to LlamaTokenizerFast’s decode(). Please refer to the docstring of this method for more information.

LlavaForConditionalGeneration

class transformers.LlavaForConditionalGeneration

( config: LlavaConfig )

Parameters

  • config (LlavaConfig or LlavaVisionConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The LLAVA model which consists of a vision backbone and a language model. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: LongTensor = None pixel_values: FloatTensor = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[typing.List[torch.FloatTensor]] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None vision_feature_layer: typing.Optional[int] = None vision_feature_select_strategy: typing.Optional[str] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None num_logits_to_keep: int = 0 ) transformers.models.llava.modeling_llava.LlavaCausalLMOutputWithPast or tuple(torch.FloatTensor)

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.

    Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.

    What are input IDs?

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size)) — The tensors corresponding to the input images. Pixel values can be obtained using AutoImageProcessor. See CLIPImageProcessor.__call__() for details (LlavaProcessor uses CLIPImageProcessor for processing images).
  • attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:

    • 1 for tokens that are not masked,
    • 0 for tokens that are masked.

    What are attention masks?

    If past_key_values is used, optionally only the last input_ids have to be input (see past_key_values).

    If you want to change padding behavior, you should read modeling_opt._prepare_decoder_attention_mask and modify to your needs. See diagram 1 in the paper for more information on the default strategy.
  • position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.n_positions - 1]. What are position IDs?
  • past_key_values (tuple(tuple(torch.FloatTensor)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

    If past_key_values are used, the user can optionally input only the last input_ids (those that don’t have their past key value states given to this model) of shape (batch_size, 1) instead of all input_ids of shape (batch_size, sequence_length).

  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
  • vision_feature_layer (int, optional, defaults to -2) — The index of the layer to select the vision feature.
  • vision_feature_select_strategy (str, optional, defaults to "default") — The feature selection strategy used to select the vision feature from the vision backbone. Can be one of "default" or "full".
  • use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Contrary to position_ids, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.
  • labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for computing the masked language modeling loss. Indices should either be in [0, ..., config.vocab_size] or -100 (see input_ids docstring). Tokens with indices set to -100 are ignored (masked); the loss is only computed for the tokens with labels in [0, ..., config.vocab_size].
  • num_logits_to_keep (int, optional, defaults to 0) — Calculate logits for the last num_logits_to_keep tokens. If 0, calculate logits for all input_ids (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes significant for long sequences or large vocabulary sizes.
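Two of the parameters above lend themselves to a small worked example. The sketch below is plain Python with hypothetical helpers (`mask_prompt_labels`, `logits_bytes`) and illustrative numbers. It shows how prompt positions can be set to -100 so they are ignored by the loss, and how much logits memory num_logits_to_keep can save, assuming 2 bytes per value (half precision) and a round vocabulary size of 32,000.

```python
IGNORE_INDEX = -100  # label value skipped by the loss


def mask_prompt_labels(input_ids, prompt_length, ignore_index=IGNORE_INDEX):
    """Copy input_ids into labels, ignoring the first prompt_length tokens."""
    return [ignore_index] * prompt_length + list(input_ids[prompt_length:])


def logits_bytes(batch_size, sequence_length, vocab_size,
                 bytes_per_element=2, num_logits_to_keep=0):
    """Approximate size of the logits tensor in bytes.

    num_logits_to_keep == 0 means logits are computed for every position.
    """
    if num_logits_to_keep == 0:
        kept = sequence_length
    else:
        kept = min(num_logits_to_keep, sequence_length)
    return batch_size * kept * vocab_size * bytes_per_element


# Hypothetical token ids: 3 prompt tokens followed by 2 answer tokens
labels = mask_prompt_labels([11, 12, 13, 21, 22], prompt_length=3)
print(labels)  # [-100, -100, -100, 21, 22]

all_positions = logits_bytes(1, 4096, 32_000)
last_only = logits_bytes(1, 4096, 32_000, num_logits_to_keep=1)
print(all_positions)  # 262144000
print(last_only)      # 64000
```

With these illustrative numbers, keeping only the last position shrinks the logits by a factor equal to the sequence length.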

Returns

transformers.models.llava.modeling_llava.LlavaCausalLMOutputWithPast or tuple(torch.FloatTensor)

A transformers.models.llava.modeling_llava.LlavaCausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (LlavaConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss (for next-token prediction).

  • logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • past_key_values (tuple(tuple(torch.FloatTensor)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • image_hidden_states (torch.FloatTensor, optional) — A torch.FloatTensor of size (batch_size, num_images, sequence_length, hidden_size). image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.

The LlavaForConditionalGeneration forward method overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Example:

>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration

>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = processor(images=image, text=prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER:  \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"