Mamba
Overview
The Mamba model was proposed in Mamba: Linear-Time Sequence Modeling with Selective State Spaces by Albert Gu and Tri Dao.
This model is a new architecture paradigm based on state space models. You can read more about the intuition behind these here.
The abstract from the paper is the following:
Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.
Tips:
- Mamba is a new state space model architecture that rivals the classic Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of FlashAttention.
- Mamba stacks mixer layers, which are the equivalent of Attention layers. The core logic of mamba is held in the MambaMixer class.
- Two implementations cohabit: one is optimized and uses fast cuda kernels, while the other one is naive but can run on any device!
- The current implementation leverages the original cuda kernels: the equivalents of flash attention for Mamba are hosted in the mamba-ssm and the causal_conv1d repositories. Make sure to install them if your hardware supports them!
- Contributions to make the naive path faster are welcome 🤗
This model was contributed by ArthurZ. The original code can be found here.
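As a rough way to see which of the two implementations is likely to be available on a machine, one can check whether the optional kernel packages are importable. The sketch below uses only the standard library. The helper name kernel_packages_available is hypothetical, and the import names mamba_ssm and causal_conv1d are assumptions. It only detects installed packages and does not confirm that the kernels will actually run on your hardware.

```python
import importlib.util


def kernel_packages_available():
    """Report whether the optional fast-kernel packages can be imported.

    Hypothetical helper for illustration; the module names are assumed to be
    `mamba_ssm` and `causal_conv1d`.
    """
    names = ("mamba_ssm", "causal_conv1d")
    # find_spec returns None when a top-level module is not installed.
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    status = kernel_packages_available()
    for name, found in status.items():
        print(f"{name}: {'installed' if found else 'missing'}")
    if not all(status.values()):
        print("At least one package is missing, so expect the naive path.")
```

If either package is reported as missing, the naive implementation is the one to expect.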
Usage
A simple generation example:
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
# Tokenize the prompt and keep only the token ids
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
# Generate up to 10 new tokens and decode them back to text
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
Peft finetuning
The slow version is not very stable for training, and the fast one needs float32!
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
MambaConfig
class transformers.MambaConfig( vocab_size = 50280, hidden_size = 768, state_size = 16, num_hidden_layers = 32, layer_norm_epsilon = 1e-05, pad_token_id = 0, bos_token_id = 0, eos_token_id = 0, expand = 2, conv_kernel = 4, use_bias = False, use_conv_bias = True, hidden_act = 'silu', initializer_range = 0.1, residual_in_fp32 = True, time_step_rank = 'auto', time_step_scale = 1.0, time_step_min = 0.001, time_step_max = 0.1, time_step_init_scheme = 'random', time_step_floor = 0.0001, rescale_prenorm_residual = False, use_cache = True, use_mambapy = False, **kwargs )
Parameters
- vocab_size (int, optional, defaults to 50280) — Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling MambaModel.
- hidden_size (int, optional, defaults to 768) — Dimensionality of the embeddings and hidden states.
- state_size (int, optional, defaults to 16) — Shape of the state space latents.
- num_hidden_layers (int, optional, defaults to 32) — Number of hidden layers in the model.
- layer_norm_epsilon (float, optional, defaults to 1e-05) — The epsilon to use in the layer normalization layers.
- pad_token_id (int, optional, defaults to 0) — Padding token id.
- bos_token_id (int, optional, defaults to 0) — The id of the beginning of sentence token in the vocabulary.
- eos_token_id (int, optional, defaults to 0) — The id of the end of sentence token in the vocabulary.
- expand (int, optional, defaults to 2) — Expanding factor used to determine the intermediate size.
- conv_kernel (int, optional, defaults to 4) — Size of the convolution kernel.
- use_bias (bool, optional, defaults to False) — Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block.
- use_conv_bias (bool, optional, defaults to True) — Whether or not to use bias in the convolution layer of the mixer block.
- hidden_act (str, optional, defaults to "silu") — The non-linear activation function (function or string) in the decoder.
- initializer_range (float, optional, defaults to 0.1) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- residual_in_fp32 (bool, optional, defaults to True) — Whether or not residuals should be in float32. If set to False, residuals will keep the same dtype as the rest of the model.
- time_step_rank (Union[int,str], optional, defaults to "auto") — Rank of the discretization projection matrix. "auto" means that it will default to math.ceil(self.hidden_size / 16).
- time_step_scale (float, optional, defaults to 1.0) — Scale used to scale dt_proj.bias.
- time_step_min (float, optional, defaults to 0.001) — Minimum time_step used to bound dt_proj.bias.
- time_step_max (float, optional, defaults to 0.1) — Maximum time_step used to bound dt_proj.bias.
- time_step_init_scheme (str, optional, defaults to "random") — Init scheme used for dt_proj.weight. Should be one of ["random", "uniform"].
- time_step_floor (float, optional, defaults to 0.0001) — Minimum clamping value of the dt_proj.bias layer initialization.
- rescale_prenorm_residual (bool, optional, defaults to False) — Whether or not to rescale out_proj weights when initializing.
- use_cache (bool, optional, defaults to True) — Whether or not the cache should be used.
- use_mambapy (bool, optional, defaults to False) — Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If True, the mamba.py implementation is used. If False, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
This is the configuration class to store the configuration of a MambaModel. It is used to instantiate a MAMBA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the MAMBA state-spaces/mamba-2.8b architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
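Two of the sizes described in the parameter list are derived rather than set directly. The sketch below does not import transformers; it only reproduces the arithmetic stated in the parameter descriptions: time_step_rank resolves to math.ceil(hidden_size / 16) when set to "auto", and expand scales hidden_size to give the intermediate size. The function name derived_sizes is hypothetical, and the formula intermediate_size = expand * hidden_size is an assumption based on the description of expand.

```python
import math


def derived_sizes(hidden_size=768, expand=2, time_step_rank="auto"):
    """Return (intermediate_size, time_step_rank) for the given settings.

    Hypothetical helper; mirrors the documented defaults, not library code.
    """
    # Assumption: the intermediate size is the hidden size scaled by `expand`.
    intermediate_size = int(expand * hidden_size)
    # "auto" resolves to ceil(hidden_size / 16), as stated in the parameter list.
    if time_step_rank == "auto":
        time_step_rank = math.ceil(hidden_size / 16)
    return intermediate_size, time_step_rank


if __name__ == "__main__":
    print(derived_sizes())                  # documented defaults
    print(derived_sizes(hidden_size=2560))  # a larger hidden size
```

With the documented defaults this gives an intermediate size of 1536 and a time-step rank of 48.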
Example:
>>> from transformers import MambaConfig, MambaModel
>>> # Initializing a Mamba configuration
>>> configuration = MambaConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = MambaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
MambaModel
class transformers.MambaModel( config )
Parameters
- config (MambaConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The bare MAMBA Model transformer outputting raw hidden-states without any specific head on top.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
( input_ids: typing.Optional[torch.LongTensor] = None, inputs_embeds: typing.Optional[torch.LongTensor] = None, cache_params: typing.Optional[transformers.cache_utils.MambaCache] = None, use_cache: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None, cache_position: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.LongTensor] = None ) → transformers.models.mamba.modeling_mamba.MambaOutput or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, input_ids_length)) — Indices of input sequence tokens in the vocabulary. If cache_params.seqlen_offset>0, only input_ids that do not have their past calculated should be passed as input_ids. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix.
- cache_params (MambaCache, optional) — If passed along, the model uses the previous state in all the blocks (which will give the output for the input_ids provided as if the model had received state_input_ids + input_ids as context).
- use_cache (bool, optional) — If set to True, the cache_params is returned and can be used to quickly generate the next logits.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Contrary to position_ids, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.
Returns
transformers.models.mamba.modeling_mamba.MambaOutput or tuple(torch.FloatTensor)
A transformers.models.mamba.modeling_mamba.MambaOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (MambaConfig) and inputs.
- last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.
- cache_params (MambaCache) — The state of the model at the last time step. Can be used in a forward method with the next input_ids to avoid providing the old input_ids. Includes both the state space model state matrices after the selective scan and the convolutional states.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
The MambaModel forward method overrides the __call__ special method.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoTokenizer, MambaModel
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
>>> model = MambaModel.from_pretrained("state-spaces/mamba-130m-hf")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
MambaForCausalLM
class transformers.MambaForCausalLM( config )
Parameters
- config (MambaConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
( input_ids: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.LongTensor] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, cache_params: typing.Optional[transformers.cache_utils.MambaCache] = None, labels: typing.Optional[torch.LongTensor] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None, use_cache: typing.Optional[bool] = None, cache_position: typing.Optional[torch.Tensor] = None, **kwargs ) → transformers.models.mamba.modeling_mamba.MambaCausalLMOutput or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, input_ids_length)) — Indices of input sequence tokens in the vocabulary. If cache_params.seqlen_offset>0, only input_ids that do not have their past calculated should be passed as input_ids. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix.
- cache_params (MambaCache, optional) — If passed along, the model uses the previous state in all the blocks (which will give the output for the input_ids provided as if the model had received state_input_ids + input_ids as context).
- use_cache (bool, optional) — If set to True, the cache_params is returned and can be used to quickly generate the next logits.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Contrary to position_ids, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.
- labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. Indices are selected in [-100, 0, ..., config.vocab_size]. All labels set to -100 are ignored (masked); the loss is only computed for labels in [0, ..., config.vocab_size].
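The labels argument can be set to input_ids because the shift happens inside the model, and positions labelled -100 are left out of the loss. The pure-Python sketch below illustrates that convention for a single sequence. It shows the usual next-token shift rather than the library's own implementation, the function name shifted_lm_loss is hypothetical, and it assumes the loss is averaged over the unmasked positions.

```python
import math

IGNORE_INDEX = -100  # labels with this value are excluded from the loss


def shifted_lm_loss(logits, labels):
    """Mean next-token cross-entropy for one sequence (illustrative only).

    logits: list of per-position score lists, shape (seq_len, vocab_size)
    labels: list of int token ids, shape (seq_len,)
    """
    # Position t predicts token t + 1: drop the last logit row and the first label.
    shifted_logits = logits[:-1]
    shifted_labels = labels[1:]

    total, count = 0.0, 0
    for scores, target in zip(shifted_logits, shifted_labels):
        if target == IGNORE_INDEX:
            continue  # masked position, contributes nothing
        # Numerically stable log-sum-exp for the softmax normaliser.
        top = max(scores)
        log_z = top + math.log(sum(math.exp(s - top) for s in scores))
        total += log_z - scores[target]
        count += 1
    # Assumption: average over unmasked positions; NaN if everything is masked.
    return total / count if count else float("nan")


if __name__ == "__main__":
    uniform = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
    print(shifted_lm_loss(uniform, [1, 2, 3]))
    print(shifted_lm_loss(uniform, [1, IGNORE_INDEX, 3]))
```

With uniform scores over a 4-token vocabulary, every unmasked position contributes log(4), so masking a position changes how many terms are averaged but not the mean.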
Returns
transformers.models.mamba.modeling_mamba.MambaCausalLMOutput or tuple(torch.FloatTensor)
A transformers.models.mamba.modeling_mamba.MambaCausalLMOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (MambaConfig) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss (for next-token prediction).
- logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- cache_params (MambaCache) — The state of the model at the last time step. Can be used in a forward method with the next input_ids to avoid providing the old input_ids. Includes both the state space model state matrices after the selective scan and the convolutional states.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
The MambaForCausalLM forward method overrides the __call__ special method.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
Example:
>>> import torch
>>> from transformers import AutoTokenizer, MambaForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> loss = outputs.loss
>>> logits = outputs.logits