import os
import platform
import torch
import traceback
import re

from easydiffusion.utils import log

"""
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh`,
to force full precision (i.e. float32). Otherwise the models will load at half precision
(i.e. float16).

Half precision is fine most of the time. Full precision is only needed for working around
GPU bugs (like on NVIDIA 16xx GPUs).
"""

COMPARABLE_GPU_PERCENTILE = (
    0.65  # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
)

mem_free_threshold = 0


def get_device_delta(render_devices, active_devices):
    """
    render_devices: 'cpu', 'auto', 'mps', or ['cuda:N'...]
    active_devices: ['cpu', 'mps', 'cuda:N'...]
    """

    if render_devices in ("cpu", "auto", "mps"):
        render_devices = [render_devices]
    elif render_devices is not None:
        if isinstance(render_devices, str):
            render_devices = [render_devices]
        if isinstance(render_devices, list) and len(render_devices) > 0:
            render_devices = list(filter(lambda x: x.startswith("cuda:") or x == "mps", render_devices))
            if len(render_devices) == 0:
                raise Exception(
                    'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]},'
                    ' or {"render_devices": "cpu"} or {"render_devices": "mps"} or {"render_devices": "auto"}'
                )

            render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
            if len(render_devices) == 0:
                raise Exception(
                    "Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion"
                )
        else:
            raise Exception(
                'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]},'
                ' or {"render_devices": "cpu"} or {"render_devices": "mps"} or {"render_devices": "auto"}'
            )
    else:
        render_devices = ["auto"]

    if "auto" in render_devices:
        render_devices = auto_pick_devices(active_devices)
        if "cpu" in render_devices:
            log.warn("WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!")

    active_devices = set(active_devices)
    render_devices = set(render_devices)

    devices_to_start = render_devices - active_devices
    devices_to_stop = active_devices - render_devices

    return devices_to_start, devices_to_stop


def is_mps_available():
    return (
        platform.system() == "Darwin"
        and hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    )


def is_cuda_available():
    return torch.cuda.is_available()


def auto_pick_devices(currently_active_devices):
    global mem_free_threshold

    if is_mps_available():
        return ["mps"]

    if not is_cuda_available():
        return ["cpu"]

    device_count = torch.cuda.device_count()
    if device_count == 1:
        return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"]

    log.debug("Autoselecting GPU. Using most free memory.")
    devices = []
    for device in range(device_count):
        device = f"cuda:{device}"
        if not is_device_compatible(device):
            continue

        mem_free, mem_total = torch.cuda.mem_get_info(device)
        mem_free /= float(10**9)
        mem_total /= float(10**9)
        device_name = torch.cuda.get_device_name(device)
        log.debug(
            f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}GB / {round(mem_total, 2)}GB"
        )
        devices.append({"device": device, "device_name": device_name, "mem_free": mem_free})

    devices.sort(key=lambda x: x["mem_free"], reverse=True)
    max_mem_free = devices[0]["mem_free"]
    curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free
    mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold)

    # Auto-pick algorithm:
    # 1. Pick the GPUs whose free memory is within COMPARABLE_GPU_PERCENTILE (65%) of the
    #    GPU with the most free memory.
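    #    Illustrative example (hypothetical numbers): if the GPU with the most free memory
    #    has 10 GB free, then curr_mem_free_threshold is 0.65 * 10 = 6.5 GB, so any other
    #    GPU with more than 6.5 GB free counts as comparable and stays in the candidate list.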
    # 2. Also include already-running devices (GPU-only), otherwise their free_mem will
    #    always be very low (since their VRAM contains the model).
    #    These already-running devices probably aren't terrible, since they were picked in the past.
    #    Worst case, the user can restart the program and that'll get rid of them.
    devices = list(
        filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
    )
    devices = list(map(lambda x: x["device"], devices))
    return devices


def device_init(context, device):
    """
    This function assumes the 'device' has already been verified to be compatible.
    `get_device_delta()` has already filtered out incompatible devices.
    """

    validate_device_id(device, log_prefix="device_init")

    if "cuda" not in device:
        context.device = device
        context.device_name = get_processor_name()
        context.half_precision = False
        log.debug(f"Render device available as {context.device_name}")
        return

    context.device_name = torch.cuda.get_device_name(device)
    context.device = device

    # Force full precision on 1660 and 1650 NVIDIA cards, to avoid creating green images
    if needs_to_force_full_precision(context):
        log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}")
        # Apply force_full_precision now, before the models are loaded.
        context.half_precision = False

    log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
    # Note: `torch.cuda.device(device)` only returns a context manager and is a no-op when
    # called bare, so set the active device explicitly.
    torch.cuda.set_device(device)


def needs_to_force_full_precision(context):
    if "FORCE_FULL_PRECISION" in os.environ:
        return True

    device_name = context.device_name.lower()
    return (
        ("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name)
        and (
            " 1660" in device_name
            or " 1650" in device_name
            or " t400" in device_name
            or " t550" in device_name
            or " t600" in device_name
            or " t1000" in device_name
            or " t1200" in device_name
            or " t2000" in device_name
        )
    ) or ("tesla k40m" in device_name)


def get_max_vram_usage_level(device):
    if "cuda" in device:
        _, mem_total = torch.cuda.mem_get_info(device)
    else:
        return "high"

    mem_total /= float(10**9)
    if mem_total < 4.5:
        return "low"
    elif mem_total < 6.5:
        return "balanced"

    return "high"


def validate_device_id(device, log_prefix=""):
    def is_valid():
        if not isinstance(device, str):
            return False
        if device == "cpu" or device == "mps":
            return True
        if not device.startswith("cuda:") or not device[5:].isnumeric():
            return False
        return True

    if not is_valid():
        raise EnvironmentError(
            f"{log_prefix}: device id should be 'cpu', 'mps', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
        )
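
# An illustrative sketch of validate_device_id's contract (the device ids below are
# hypothetical examples): the first three calls return silently, the last one raises
# EnvironmentError.
#
#   validate_device_id("cpu")
#   validate_device_id("mps")
#   validate_device_id("cuda:1")
#   validate_device_id("gpu0", log_prefix="example")  # raises EnvironmentError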


def is_device_compatible(device):
    """
    Returns True/False, and prints any compatibility errors
    """
    # static variable "history", to avoid logging the same warning repeatedly
    is_device_compatible.history = getattr(is_device_compatible, "history", {})

    try:
        validate_device_id(device, log_prefix="is_device_compatible")
    except Exception as e:
        log.error(str(e))
        return False

    if device in ("cpu", "mps"):
        return True

    # Memory check
    try:
        _, mem_total = torch.cuda.mem_get_info(device)
        mem_total /= float(10**9)
        if mem_total < 3.0:
            if is_device_compatible.history.get(device) is None:
                log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
                is_device_compatible.history[device] = 1
            return False
    except RuntimeError as e:
        log.error(str(e))
        return False

    return True


def get_processor_name():
    try:
        import subprocess

        if platform.system() == "Windows":
            return platform.processor()
        elif platform.system() == "Darwin":
            os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
            command = "sysctl -n machdep.cpu.brand_string"
            return subprocess.check_output(command, shell=True).decode().strip()
        elif platform.system() == "Linux":
            command = "cat /proc/cpuinfo"
            all_info = subprocess.check_output(command, shell=True).decode().strip()
            for line in all_info.split("\n"):
                if "model name" in line:
                    return re.sub(r".*model name.*:", "", line, count=1).strip()
    except:
        log.error(traceback.format_exc())
    # Fallback if detection failed, or if no branch above matched
    return "cpu"
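

# Minimal manual smoke test: an illustrative sketch for running this module directly.
# It assumes nothing beyond torch being installed, and is safe on CPU-only machines
# (auto_pick_devices falls back to ["cpu"] when no compatible GPU is found).
if __name__ == "__main__":
    log.info(f"Processor: {get_processor_name()}")

    picked = auto_pick_devices(currently_active_devices=[])
    log.info(f"Auto-picked render devices: {picked}")

    to_start, to_stop = get_device_delta("auto", active_devices=[])
    log.info(f"Devices to start: {to_start}, devices to stop: {to_stop}")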