"""Utility functions for training and inference.""" import functools from pathlib import Path import pickle import warnings from io import BytesIO import torch import torch.utils._device from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy from torch.distributed.fsdp import FullStateDictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType def save_model_checkpoint(fabric, model, file_path): """Handles boilerplate logic for retrieving and saving the state_dict. This will be upstreamed to Fabric soon. """ file_path = Path(file_path) if isinstance(fabric.strategy, DeepSpeedStrategy): from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict fabric.save(file_path, {"model": model}) fabric.barrier() if fabric.global_rank == 0: # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth")) return if isinstance(fabric.strategy, FSDPStrategy): save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True) with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): state_dict = model._forward_module.state_dict() else: state_dict = model.state_dict() if fabric.global_rank == 0: torch.save(state_dict, file_path) fabric.barrier() class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): def __init__(self, device=None, dtype=None, quantization_mode=None): """ Create tensors with given device and dtype and don't run initialization (but instead use "empty tensors", i.e. uninitialized memory). device: `torch.device` to work with dtype: `torch.dtype` to work with quantization_mode: optional string, quantization mode to work with, default `None`. 


class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    def __init__(self, device=None, dtype=None, quantization_mode=None):
        """
        Create tensors with given device and dtype and don't run initialization
        (but instead use "empty tensors", i.e. uninitialized memory).

        device: `torch.device` to work with
        dtype: `torch.dtype` to work with
        quantization_mode: optional string, quantization mode to work with, default `None`.
            Available modes: `llm.int8` bitsandbytes LLM.int8 quantization (only on GPU)
                             `gptq.int4`, `gptq.int8`: GPTQ pre-quantized models

        Example::
            with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
                model = LLaMA.from_name('7B')
            model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))
        """
        self.quantization_mode = quantization_mode
        self.quantized_linear_cls = None
        if self.quantization_mode == "llm.int8":
            if device.type != "cuda":
                raise ValueError("Quantization is only supported on the GPU.")
            from .quantization import Linear8bitLt

            self.quantized_linear_cls = Linear8bitLt
        elif self.quantization_mode == "gptq.int4":
            from .quantization import ColBlockQuantizedLinear

            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
        elif self.quantization_mode == "gptq.int8":
            from .quantization import ColBlockQuantizedLinear

            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
        elif self.quantization_mode is not None:
            raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
        self.device = device
        self.dtype = dtype

    def __enter__(self):
        if self.quantized_linear_cls is not None:
            # Swap in the quantized linear class while the context is active so
            # that any `torch.nn.Linear` constructed inside uses it.
            self.torch_linear_cls = torch.nn.Linear
            torch.nn.Linear = self.quantized_linear_cls
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.quantized_linear_cls is not None:
            torch.nn.Linear = self.torch_linear_cls
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Turn `torch.nn.init.*` calls into no-ops that return the input tensor
        # unchanged, skipping the pointless random initialization.
        if getattr(func, "__module__", None) == "torch.nn.init":
            if "tensor" in kwargs:
                return kwargs["tensor"]
            return args[0]
        # Redirect tensor constructors to the requested device/dtype unless the
        # caller specified them explicitly.
        if (
            self.device is not None
            and func in torch.utils._device._device_constructors()
            and kwargs.get("device") is None
        ):
            kwargs["device"] = self.device
        if (
            self.dtype is not None
            and func in torch.utils._device._device_constructors()
            and kwargs.get("dtype") is None
        ):
            kwargs["dtype"] = self.dtype
        return func(*args, **kwargs)
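
# A small self-contained sketch of what `EmptyInitOnDevice` does, using a plain
# `torch.nn.Linear` as a stand-in for the lit-llama model in the docstring
# example. Inside the context, tensors land on the target device/dtype and the
# `torch.nn.init` calls become no-ops, so construction is near-instant; real
# weights must then be filled in via `load_state_dict`::
#
#     with EmptyInitOnDevice(device=torch.device("cpu"), dtype=torch.float32):
#         layer = torch.nn.Linear(4096, 4096)  # uninitialized memory, no init cost
#     layer.load_state_dict(torch.load("weights.pth"))  # hypothetical weight file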


# this is taken from torchhacks https://github.com/lernapparat/torchhacks
class NotYetLoadedTensor:
    def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
        self.metatensor = metatensor
        self.archiveinfo = archiveinfo
        self.storageinfo = storageinfo
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild(
        cls,
        storage,
        storage_offset,
        size,
        stride,
        requires_grad,
        backward_hooks,
        metadata=None,
        archiveinfo=None,
    ):
        rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
        metatensor = torch._utils._rebuild_tensor_v2(
            storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
        )
        storageinfo = storage.archiveinfo
        return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)

    def _load_tensor(self):
        name, storage_cls, fn, device, size = self.storageinfo
        dtype = self.metatensor.dtype

        uts = (
            self.archiveinfo.zipfile.get_storage_from_record(
                f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage
            )
            ._typed_storage()
            ._untyped_storage
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True)
        tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
        return tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args]
        res = func(*loaded_args, **kwargs)
        # gc.collect would be costly here, maybe do it optionally
        return res

    def __getattr__(self, name):
        # properties
        ## TODO: device, is_...??
        ## TODO: mH, mT, H, T, data, imag, real
        ## name ???
        if name in {
            "dtype",
            "grad",
            "grad_fn",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "shape",
            "volatile",
        }:
            return getattr(self.metatensor, name)
        if name in {"size"}:
            return getattr(self.metatensor, name)
        # materializing with contiguous is needed for quantization
        if name in {"contiguous"}:
            return getattr(self._load_tensor(), name)
        raise AttributeError(f"{type(self)} does not have {name}")

    def __repr__(self):
        return f"NotYetLoadedTensor({repr(self.metatensor)})"


class LazyLoadingUnpickler(pickle.Unpickler):
    def __init__(self, file, zipfile):
        super().__init__(file)
        self.zipfile = zipfile

    def find_class(self, module, name):
        # Intercept tensor rebuilding so that tensors come back as lazy
        # `NotYetLoadedTensor` placeholders instead of materialized tensors.
        if module == "torch._utils" and name == "_rebuild_tensor_v2":
            return functools.partial(NotYetLoadedTensor.rebuild, archiveinfo=self)
        return super().find_class(module, name)

    def persistent_load(self, pid):
        # Return a meta-device storage stub; the real bytes are fetched from the
        # zip archive only when the tensor is first used.
        name, cls, fn, device, size = pid
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
        s.archiveinfo = pid
        return s


def lazy_load(fn):
    zf = torch._C.PyTorchFileReader(str(fn))
    with BytesIO(zf.get_record("data.pkl")) as pkl:
        mup = LazyLoadingUnpickler(pkl, zf)
        sd = mup.load()
    return sd
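
# A minimal sketch of `lazy_load` together with `EmptyInitOnDevice` (the
# checkpoint path and `LLaMA` model class are assumptions mirroring the
# docstring example above). The returned state dict holds `NotYetLoadedTensor`
# placeholders whose storages are only read from the zip archive when first
# used, so peak memory stays close to the size of the tensors actually touched::
#
#     checkpoint = lazy_load("llama-lit/7B/lit-llama.pth")
#     with EmptyInitOnDevice(device=torch.device("cuda"), dtype=torch.bfloat16):
#         model = LLaMA.from_name("7B")
#     model.load_state_dict(checkpoint)  # each tensor is materialized on copy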