Bitsandbytes documentation

8-bit quantization

You are viewing version v0.43.3. A newer version, v0.44.1, is available.

LLM.int8() is a quantization method that doesn’t degrade performance, which makes large model inference more accessible. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are quantized to Int8, multiplied in 8-bit, and then dequantized back to 16-bit. The outputs from the 16-bit and 8-bit multiplications are combined to produce the final output.
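The decomposition described above can be sketched in plain Python. This is a dependency-free illustration, not the library's CUDA implementation: the function names, the absmax scaling scheme, and the threshold value of 6.0 are assumptions chosen for the sketch. Outlier feature dimensions are detected in the inputs, and the matching weight rows are kept in floating point alongside them.

```python
def absmax_quantize(vec):
    """Scale a list of floats into the int8 range [-127, 127]; return (codes, scale)."""
    scale = max((abs(v) for v in vec), default=0.0) or 1.0
    return [round(v / scale * 127) for v in vec], scale


def int8_matmul_with_outliers(x, w, threshold=6.0):
    """Multiply x (n x k) by w (k x m), keeping outlier dimensions in floating point.

    `threshold` is an illustrative value for this sketch, not a library default.
    """
    n, k, m = len(x), len(w), len(w[0])
    # A feature dimension is an outlier if any input magnitude reaches the threshold.
    is_outlier = [any(abs(row[j]) >= threshold for row in x) for j in range(k)]
    regular = [j for j in range(k) if not is_outlier[j]]
    outliers = [j for j in range(k) if is_outlier[j]]

    # Quantize the regular part: one scale per input row, one per weight column.
    xq, xs = zip(*(absmax_quantize([row[j] for j in regular]) for row in x))
    wq, ws = zip(*(absmax_quantize([w[j][c] for j in regular]) for c in range(m)))

    out = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for c in range(m):
            # 8-bit path: integer accumulation, then dequantize with both scales.
            acc = sum(a * b for a, b in zip(xq[i], wq[c]))
            low = acc * xs[i] * ws[c] / (127 * 127)
            # Outlier path: full-precision multiply (16-bit in the real method).
            high = sum(x[i][j] * w[j][c] for j in outliers)
            out[i][c] = low + high
    return out


x = [[0.5, -1.0, 8.0], [1.5, 0.25, -0.5]]   # third feature holds an outlier (8.0)
w = [[1.0, -2.0], [0.5, 0.25], [-1.0, 3.0]]
print(int8_matmul_with_outliers(x, w))
```

Routing the large-magnitude dimension around the quantizer keeps it from inflating the absmax scale, which would otherwise squeeze every other value in the row into a few Int8 codes.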

Linear8bitLt

class bitsandbytes.nn.Linear8bitLt

( input_features: int, output_features: int, bias = True, has_fp16_weights = True, memory_efficient_backward = False, threshold = 0.0, index = None, device = None )

This class is the base module for the LLM.int8() algorithm. To read more about it, have a look at the paper, "LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale" (Dettmers et al., 2022).

To quantize a linear layer, first load the original fp16 / bf16 weights into the Linear8bitLt module, then call int8_module.to("cuda"). Moving the module to the GPU is the step that quantizes the weights.

Example:

import torch
import torch.nn as nn

import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt

# Reference model holding the original (unquantized) weights.
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64)
)

# Same architecture built from Linear8bitLt layers.
int8_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False),
    Linear8bitLt(64, 64, has_fp16_weights=False)
)

# Copy the original weights in, then move to the GPU (device index 0).
int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0) # Quantization happens here

__init__

( input_features: int, output_features: int, bias = True, has_fp16_weights = True, memory_efficient_backward = False, threshold = 0.0, index = None, device = None )

Parameters

  • input_features (int) — Number of input features of the linear layer.
  • output_features (int) — Number of output features of the linear layer.
  • bias (bool, defaults to True) — Whether the linear layer uses a bias term.

Initialize Linear8bitLt class.

Int8Params

class bitsandbytes.nn.Int8Params

( data = None, requires_grad = True, has_fp16_weights = False, CB = None, SCB = None )

__init__

( *args, **kwargs )

Initialize self. See help(type(self)) for accurate signature.
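
This page does not describe the CB and SCB arguments. As a rough mental model, and an assumption rather than something stated here, CB can be read as the Int8 codes of the weight matrix and SCB as the per-row absolute maxima used to scale them. The plain-Python sketch below illustrates that storage scheme and its round trip; the helper names are hypothetical and are not part of bitsandbytes.

```python
def quantize_rows(weight):
    """Return (cb, scb): int8 codes for each row and that row's absolute maximum."""
    cb, scb = [], []
    for row in weight:
        absmax = max((abs(v) for v in row), default=0.0) or 1.0
        cb.append([round(v / absmax * 127) for v in row])
        scb.append(absmax)
    return cb, scb


def dequantize_rows(cb, scb):
    """Approximately reconstruct the original weights from (cb, scb)."""
    return [[code * absmax / 127 for code in row] for row, absmax in zip(cb, scb)]


weight = [[0.02, -0.5, 0.25], [1.0, 0.1, -0.75]]
cb, scb = quantize_rows(weight)
restored = dequantize_rows(cb, scb)
print(cb, scb, restored)
```

Under this reading, only the Int8 codes and one scale per row need to be kept, and the reconstruction error for any value is at most half a quantization step of its row.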
