# ztrain/util.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

import contextlib

import torch


@contextlib.contextmanager
def cuda_memory_profiler(display: bool = True):
    """
    A context manager for profiling CUDA memory usage in PyTorch.

    Prints peak, start, end, and net allocated memory (in MB) for the
    wrapped block. Pass ``display=False`` to disable profiling entirely.
    """
    if not display:
        yield
        return

    if not torch.cuda.is_available():
        print("CUDA is not available, skipping memory profiling")
        yield
        return

    # Reset the peak-allocation counter and take a baseline reading.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start_memory = torch.cuda.memory_allocated()

    try:
        yield
    finally:
        # Synchronize so pending kernels are reflected in the final reading.
        torch.cuda.synchronize()
        end_memory = torch.cuda.memory_allocated()
        print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / (1024 ** 2):.2f} MB")
        print(f"Memory allocated at start: {start_memory / (1024 ** 2):.2f} MB")
        print(f"Memory allocated at end: {end_memory / (1024 ** 2):.2f} MB")
        print(f"Net memory change: {(end_memory - start_memory) / (1024 ** 2):.2f} MB")


def get_device():
    """Return the best available torch device: CUDA, then MPS, then CPU."""
    return torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
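

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of wrapping a forward/backward pass with
# cuda_memory_profiler and placing tensors via get_device(). The model and
# input shapes below are hypothetical, chosen only to exercise the API;
# on a machine without CUDA the profiler prints a skip notice and the block
# still runs on MPS or CPU.
if __name__ == "__main__":
    device = get_device()
    model = torch.nn.Linear(1024, 1024).to(device)
    x = torch.randn(64, 1024, device=device)

    with cuda_memory_profiler(display=True):
        y = model(x)
        loss = y.sum()
        loss.backward()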