8-bit quantization
LLM.int8() is a quantization method that makes large model inference more accessible without degrading performance. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are quantized to Int8 and multiplied in 8-bit before the result is dequantized back to 16-bit. The outputs of the 16-bit and 8-bit multiplications are then combined to produce the final output.
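Conceptually, the decomposition can be sketched in plain PyTorch as follows. This is only an illustration of the idea, not the library's actual CUDA kernels; the function name, the per-row/per-column absmax scaling, and the threshold value are assumptions for the sketch (6.0 is the outlier threshold recommended in the paper).

import torch

def llm_int8_matmul_sketch(x, w, threshold=6.0):
    # Hypothetical helper illustrating the LLM.int8() decomposition.
    # x: (batch, in_features) inputs, w: (in_features, out_features) weights.

    # Input columns containing at least one value above the threshold are outliers.
    outliers = x.abs().amax(dim=0) > threshold

    # 16-bit path: outlier columns of x against the matching rows of w.
    out_fp16 = x[:, outliers] @ w[outliers, :]

    # 8-bit path: absmax-quantize the remaining values to int8.
    x_rest, w_rest = x[:, ~outliers], w[~outliers, :]
    sx = x_rest.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0  # per row of x
    sw = w_rest.abs().amax(dim=0, keepdim=True).clamp(min=1e-8) / 127.0  # per column of w
    x_i8 = (x_rest / sx).round().to(torch.int8)
    w_i8 = (w_rest / sw).round().to(torch.int8)

    # Integer matmul (accumulated in int32), then dequantize with the
    # outer product of the two scale vectors.
    acc = x_i8.to(torch.int32) @ w_i8.to(torch.int32)
    out_int8 = acc * (sx * sw)

    # Combine the two paths into the final output.
    return out_fp16 + out_int8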
Linear8bitLt
class bitsandbytes.nn.Linear8bitLt
( input_features: int, output_features: int, bias = True, has_fp16_weights = True, memory_efficient_backward = False, threshold = 0.0, index = None, device = None )
This class is the base module for the LLM.int8() algorithm. To read more about it, have a look at the LLM.int8() paper: https://arxiv.org/abs/2208.07339.
To quantize a linear layer, first load the original fp16/bf16 weights into the Linear8bitLt module, then call int8_module.to("cuda") to quantize the fp16 weights.
Example:
import torch
import torch.nn as nn

from bitsandbytes.nn import Linear8bitLt

# A reference model with ordinary linear layers.
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64)
)

# The same architecture built from Linear8bitLt modules.
# has_fp16_weights=False keeps the weights in int8 after quantization,
# which is what you want for inference.
int8_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False),
    Linear8bitLt(64, 64, has_fp16_weights=False)
)

# Load the fp16 weights, then move the module to the GPU.
int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0)  # Quantization happens here
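Once quantized, the model is used like any other PyTorch module. A minimal inference sketch, assuming a CUDA device is available (the batch size and random input are illustrative):

input = torch.randn(8, 64, dtype=torch.float16, device="cuda")
with torch.no_grad():
    output = int8_model(input)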
__init__
( input_features: int, output_features: int, bias = True, has_fp16_weights = True, memory_efficient_backward = False, threshold = 0.0, index = None, device = None )
Initialize the Linear8bitLt class.
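For inference, the layer is typically constructed with has_fp16_weights=False and a nonzero outlier threshold. A minimal sketch; the value 6.0 mirrors the threshold recommended in the LLM.int8() paper:

from bitsandbytes.nn import Linear8bitLt

# threshold=6.0 routes activation outliers above that magnitude through
# the 16-bit path instead of the int8 path.
layer = Linear8bitLt(64, 64, bias=True, has_fp16_weights=False, threshold=6.0)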
Int8Params
class bitsandbytes.nn.Int8Params
( data = None, requires_grad = True, has_fp16_weights = False, CB = None, SCB = None )
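This parameter class carries the quantized state of a Linear8bitLt layer: after an inference-mode module (has_fp16_weights=False) is moved to the GPU, CB holds the int8 weight matrix and SCB the row-wise scaling constants used for dequantization. A small inspection sketch, reusing int8_model from the example above (the attribute contents shown are assumptions based on this description):

param = int8_model[0].weight   # an Int8Params instance after .to(0)
print(param.CB.dtype)          # torch.int8 quantized weights
print(param.SCB.shape)         # scaling constants, one per output row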