FluxTransformer2DModel
A Transformer model for image-like data from Flux.
class diffusers.FluxTransformer2DModel
(patch_size: int = 1, in_channels: int = 64, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, num_attention_heads: int = 24, joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, axes_dims_rope: List = [16, 56, 56])
Parameters
- patch_size (int, defaults to 1): Patch size used to turn the input data into small patches.
- in_channels (int, optional, defaults to 64): The number of channels in the input.
- num_layers (int, optional, defaults to 19): The number of MMDiT blocks to use.
- num_single_layers (int, optional, defaults to 38): The number of single DiT blocks to use.
- attention_head_dim (int, optional, defaults to 128): The number of channels in each attention head.
- num_attention_heads (int, optional, defaults to 24): The number of heads to use for multi-head attention.
- joint_attention_dim (int, optional, defaults to 4096): The number of encoder_hidden_states dimensions to use.
- pooled_projection_dim (int, defaults to 768): The number of dimensions to use when projecting the pooled_projections.
- guidance_embeds (bool, defaults to False): Whether to use guidance embeddings.
- axes_dims_rope (List, defaults to [16, 56, 56]): The rotary positional embedding dimensions for each axis of the position ids; with the defaults they sum to attention_head_dim (16 + 56 + 56 = 128).
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
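The defaults above fix a few derived sizes. The sketch below uses a hypothetical FluxConfigSketch dataclass (not part of diffusers) that only restates the documented constructor defaults in plain Python. The transformer width is num_attention_heads * attention_head_dim, and the check that the RoPE axis dimensions add up to attention_head_dim is our reading of the defaults (16 + 56 + 56 = 128), not a rule quoted from the library.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FluxConfigSketch:
    """Hypothetical mirror of the FluxTransformer2DModel constructor defaults.

    Not part of diffusers; it restates the documented defaults so that the
    derived sizes can be checked with plain Python.
    """

    patch_size: int = 1
    in_channels: int = 64
    num_layers: int = 19
    num_single_layers: int = 38
    attention_head_dim: int = 128
    num_attention_heads: int = 24
    joint_attention_dim: int = 4096
    pooled_projection_dim: int = 768
    guidance_embeds: bool = False
    axes_dims_rope: List[int] = field(default_factory=lambda: [16, 56, 56])

    @property
    def inner_dim(self) -> int:
        # Width of the transformer stream: one attention_head_dim slice per head.
        return self.num_attention_heads * self.attention_head_dim

    @property
    def total_blocks(self) -> int:
        # MMDiT (dual-stream) blocks plus single-stream DiT blocks.
        return self.num_layers + self.num_single_layers

    def check(self) -> None:
        # Assumption: the per-axis RoPE dimensions together cover one full head.
        if sum(self.axes_dims_rope) != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope)={sum(self.axes_dims_rope)} "
                f"!= attention_head_dim={self.attention_head_dim}"
            )


cfg = FluxConfigSketch()
cfg.check()
print(cfg.inner_dim, cfg.total_blocks)  # 3072 57
```

Changing num_attention_heads or attention_head_dim in such a sketch immediately shows the effect on the model width, which is a cheap way to sanity-check a smaller test configuration before instantiating the real model.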
forward
(hidden_states: Tensor, encoder_hidden_states: Tensor = None, pooled_projections: Tensor = None, timestep: LongTensor = None, img_ids: Tensor = None, txt_ids: Tensor = None, guidance: Tensor = None, joint_attention_kwargs: Optional = None, return_dict: bool = True)
Parameters
- hidden_states (torch.FloatTensor of shape (batch_size, image_sequence_length, in_channels)): Input hidden_states, that is, the image latents already packed into a sequence of patch tokens.
- encoder_hidden_states (torch.FloatTensor of shape (batch_size, sequence_len, embed_dims)): Conditional embeddings (embeddings computed from the input conditions, such as prompts) to use.
- pooled_projections (torch.FloatTensor of shape (batch_size, projection_dim)): Embeddings projected from the embeddings of the input conditions.
- timestep (torch.LongTensor): Indicates the current denoising step.
- block_controlnet_hidden_states (list of torch.Tensor): A list of tensors that, if specified, are added to the residuals of the transformer blocks.
- img_ids (torch.Tensor): Position ids of the image tokens, used to compute the rotary positional embeddings.
- txt_ids (torch.Tensor): Position ids of the text tokens, used to compute the rotary positional embeddings.
- guidance (torch.Tensor, optional): Guidance values, to be passed when the model is configured with guidance_embeds=True.
- joint_attention_kwargs (dict, optional): A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.
- return_dict (bool, optional, defaults to True): Whether to return a Transformer2DModelOutput instead of a plain tuple.
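In practice hidden_states, img_ids and txt_ids are produced by the pipeline rather than by hand. The following is a minimal pure-Python sketch, assuming the 2x2 latent packing used by the diffusers FluxPipeline and 16-channel VAE latents: each 2x2 patch of a (channels, height, width) latent becomes one token with channels * 4 features (16 * 4 = 64, the default in_channels), each image token gets the id (0, patch_row, patch_col), and text tokens get all-zero ids. The helper names pack_latents and make_ids are hypothetical, and the batch dimension is omitted.

```python
from typing import List, Tuple

Latent = List[List[List[float]]]  # indexed as [channel][row][col] for one image


def pack_latents(latent: Latent) -> List[List[float]]:
    """Hypothetical helper: fold each 2x2 spatial patch into the feature axis.

    Returns (H/2 * W/2) tokens with C * 4 features each, ordered row-major over
    patches; inside a token, features are ordered (channel, dy, dx).
    """
    channels = len(latent)
    height, width = len(latent[0]), len(latent[0][0])
    if height % 2 or width % 2:
        raise ValueError("latent height and width must be even")
    tokens = []
    for i in range(height // 2):
        for j in range(width // 2):
            token = []
            for c in range(channels):
                for dy in range(2):
                    for dx in range(2):
                        token.append(latent[c][2 * i + dy][2 * j + dx])
            tokens.append(token)
    return tokens


def make_ids(height: int, width: int, text_len: int) -> Tuple[list, list]:
    """Hypothetical helper: one 3-component position id per token.

    height and width are the latent (pre-packing) sizes. Image tokens get
    (0, patch_row, patch_col); text tokens get all zeros.
    """
    img_ids = [(0, i, j) for i in range(height // 2) for j in range(width // 2)]
    txt_ids = [(0, 0, 0)] * text_len
    return img_ids, txt_ids


# A toy 16-channel 4x4 latent whose values encode their own coordinates.
latent = [[[100 * c + 10 * y + x for x in range(4)] for y in range(4)] for c in range(16)]
tokens = pack_latents(latent)
img_ids, txt_ids = make_ids(4, 4, text_len=3)
print(len(tokens), len(tokens[0]))  # 4 64
print(img_ids)  # [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)]
```

Encoding each value's (channel, row, column) into the toy latent makes it easy to verify where every input element lands after packing.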
The FluxTransformer2DModel forward method.
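To relate the constructor's axes_dims_rope to the img_ids and txt_ids arguments of forward, here is a sketch of per-axis rotary angles: component a of a token's 3-part position id is rotated with axes_dims_rope[a] / 2 frequencies of its own, and the per-axis results are concatenated, giving 8 + 28 + 28 = 64 rotation pairs for a 128-channel head. The base theta = 10000 and the helper names axis_frequencies and rope_angles are assumptions for illustration; this is not the library's implementation.

```python
from typing import List, Sequence


def axis_frequencies(dim: int, theta: float = 10000.0) -> List[float]:
    # Standard RoPE frequencies for one axis: dim / 2 rotation pairs.
    # Assumption: theta = 10000 as the frequency base.
    return [1.0 / theta ** (k / dim) for k in range(0, dim, 2)]


def rope_angles(
    position_id: Sequence[int], axes_dims: Sequence[int], theta: float = 10000.0
) -> List[float]:
    """Hypothetical helper: rotation angle for every (cos, sin) pair of one head.

    Axis a of the position id is rotated with its own axes_dims[a] // 2
    frequencies; the per-axis angle lists are concatenated in axis order.
    """
    if len(position_id) != len(axes_dims):
        raise ValueError("need one position component per RoPE axis")
    angles: List[float] = []
    for pos, dim in zip(position_id, axes_dims):
        angles.extend(pos * f for f in axis_frequencies(dim, theta))
    return angles


# Image token at patch row 1, column 2, using the default axes_dims_rope.
angles = rope_angles((0, 1, 2), [16, 56, 56])
print(len(angles))  # 64, i.e. attention_head_dim / 2
```

Because text tokens have all-zero ids in this scheme, all of their angles are zero, so only the image tokens carry a 2D position signal.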