import os
import platform
import torch
import traceback
import re

from easydiffusion.utils import log

"""
Set the `FORCE_FULL_PRECISION` environment variable (or set it in `config.bat`/`config.sh`) to force full precision (i.e. float32).
Otherwise the models will load at half precision (i.e. float16).

Half precision is fine most of the time. Full precision is only needed to work around GPU bugs (such as NVIDIA 16xx GPUs, which produce green images at half precision).
"""

COMPARABLE_GPU_PERCENTILE = (
    0.65  # a GPU is auto-picked if its free_mem exceeds this fraction of the free_mem of the GPU with the most free memory
)

mem_free_threshold = 0


def get_device_delta(render_devices, active_devices):
    """
    render_devices: 'cpu', or 'auto', or 'mps' or ['cuda:N'...]
    active_devices: ['cpu', 'mps', 'cuda:N'...]
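    Returns a tuple (devices_to_start, devices_to_stop).

    Illustrative example (assuming both GPUs are detected and compatible):
        get_device_delta(["cuda:0", "cuda:1"], ["cuda:0", "cpu"])
        # -> ({"cuda:1"}, {"cpu"})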
    """

    if render_devices in ("cpu", "auto", "mps"):
        render_devices = [render_devices]
    elif render_devices is not None:
        if isinstance(render_devices, str):
            render_devices = [render_devices]
        if isinstance(render_devices, list) and len(render_devices) > 0:
            render_devices = list(filter(lambda x: x.startswith("cuda:") or x == "mps", render_devices))
            if len(render_devices) == 0:
                raise Exception(
                    'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "mps"} or {"render_devices": "auto"}'
                )

            render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
            if len(render_devices) == 0:
                raise Exception(
                    "Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion"
                )
        else:
            raise Exception(
                'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
            )
    else:
        render_devices = ["auto"]

    if "auto" in render_devices:
        render_devices = auto_pick_devices(active_devices)
        if "cpu" in render_devices:
            log.warn("WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!")

    active_devices = set(active_devices)
    render_devices = set(render_devices)

    devices_to_start = render_devices - active_devices
    devices_to_stop = active_devices - render_devices

    return devices_to_start, devices_to_stop


def is_mps_available():
    return (
        platform.system() == "Darwin"
        and hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    )


def is_cuda_available():
    return torch.cuda.is_available()


def auto_pick_devices(currently_active_devices):
    global mem_free_threshold

    if is_mps_available():
        return ["mps"]

    if not is_cuda_available():
        return ["cpu"]

    device_count = torch.cuda.device_count()
    if device_count == 1:
        return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"]

    log.debug("Autoselecting GPU. Using most free memory.")
    devices = []
    for device in range(device_count):
        device = f"cuda:{device}"
        if not is_device_compatible(device):
            continue

        mem_free, mem_total = torch.cuda.mem_get_info(device)
        mem_free /= float(10**9)
        mem_total /= float(10**9)
        device_name = torch.cuda.get_device_name(device)
        log.debug(
            f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}GB / {round(mem_total, 2)}GB"
        )
        devices.append({"device": device, "device_name": device_name, "mem_free": mem_free})

    if len(devices) == 0:  # none of the detected GPUs passed the compatibility check
        return ["cpu"]

    devices.sort(key=lambda x: x["mem_free"], reverse=True)
    max_mem_free = devices[0]["mem_free"]
    curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free
    mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold)

    # Auto-pick algorithm:
    # 1. Pick the GPUs whose free_mem exceeds COMPARABLE_GPU_PERCENTILE (65%) of the free_mem
    #    of the GPU with the most free memory.
    # 2. Also include already-running devices (GPU-only), otherwise their free_mem will
    #    always be very low (since their VRAM contains the model).
    #    These already-running devices probably aren't terrible, since they were picked in the past.
    #    Worst case, the user can restart the program and that'll get rid of them.
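    # For example (illustrative numbers): with two GPUs reporting 10 GB and 7 GB free, the threshold is
    # 0.65 * 10 = 6.5 GB, so both are picked; a GPU with only 5 GB free would be skipped unless it is
    # already an active render device.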
    devices = list(
        filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
    )
    devices = list(map(lambda x: x["device"], devices))
    return devices


def device_init(context, device):
    """
    This function assumes the 'device' has already been verified to be compatible.
    `get_device_delta()` has already filtered out incompatible devices.
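    Sets context.device, context.device_name and context.half_precision for the given device.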
    """

    validate_device_id(device, log_prefix="device_init")

    if "cuda" not in device:
        context.device = device
        context.device_name = get_processor_name()
        context.half_precision = False
        log.debug(f"Render device available as {context.device_name}")
        return

    context.device_name = torch.cuda.get_device_name(device)
    context.device = device

    # Force full precision on GPUs that are known to produce green images at half precision (e.g. NVIDIA 16xx cards)
    if needs_to_force_full_precision(context):
        log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}")
        # Apply force_full_precision now before models are loaded.
        context.half_precision = False

    log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
    torch.cuda.set_device(device)  # make this the active CUDA device


def needs_to_force_full_precision(context):
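    """
    True if the FORCE_FULL_PRECISION environment variable is set, or if the GPU is on the known list
    of cards (NVIDIA 16xx, Txxx series, Tesla K40m) that need full precision to avoid broken images.
    """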
    if "FORCE_FULL_PRECISION" in os.environ:
        return True

    device_name = context.device_name.lower()
    return (
        ("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name)
        and (
            " 1660" in device_name
            or " 1650" in device_name
            or " t400" in device_name
            or " t550" in device_name
            or " t600" in device_name
            or " t1000" in device_name
            or " t1200" in device_name
            or " t2000" in device_name
        )
    ) or ("tesla k40m" in device_name)


def get_max_vram_usage_level(device):
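    """
    Maps the device's total VRAM to a usage level: "low" (< 4.5 GB), "balanced" (< 6.5 GB) or "high".
    Non-CUDA devices (cpu/mps) always report "high".

    For example, an 8 GB card reports roughly 8.59e9 bytes from torch.cuda.mem_get_info(),
    i.e. ~8.59 after dividing by 1e9, so it falls in the "high" bucket.
    """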
    if "cuda" in device:
        _, mem_total = torch.cuda.mem_get_info(device)
    else:
        return "high"

    mem_total /= float(10**9)
    if mem_total < 4.5:
        return "low"
    elif mem_total < 6.5:
        return "balanced"

    return "high"


def validate_device_id(device, log_prefix=""):
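    """
    Raises EnvironmentError unless `device` is "cpu", "mps", or "cuda:N" (N being an integer GPU index).

    Illustrative examples:
        validate_device_id("cuda:1")  # ok
        validate_device_id("gpu0")    # raises EnvironmentError
    """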
    def is_valid():
        if not isinstance(device, str):
            return False
        if device == "cpu" or device == "mps":
            return True
        if not device.startswith("cuda:") or not device[5:].isnumeric():
            return False
        return True

    if not is_valid():
        raise EnvironmentError(
            f"{log_prefix}: device id should be 'cpu', 'mps', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
        )


def is_device_compatible(device):
    """
    Returns True/False, and logs any compatibility errors
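    A GPU needs at least 3 GB of VRAM to be considered compatible; "cpu" and "mps" are always compatible.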
    """
    # A function attribute is used as a "static" cache, so that each warning is only logged once.
    is_device_compatible.history = getattr(is_device_compatible, "history", {})
    try:
        validate_device_id(device, log_prefix="is_device_compatible")
    except EnvironmentError as e:
        log.error(str(e))
        return False

    if device in ("cpu", "mps"):
        return True
    # Memory check
    try:
        _, mem_total = torch.cuda.mem_get_info(device)
        mem_total /= float(10**9)
        if mem_total < 3.0:
            if is_device_compatible.history.get(device) is None:
                log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
                is_device_compatible.history[device] = 1
            return False
    except RuntimeError as e:
        log.error(str(e))
        return False
    return True


def get_processor_name():
    try:
        import subprocess

        if platform.system() == "Windows":
            return platform.processor()
        elif platform.system() == "Darwin":
            os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
            command = "sysctl -n machdep.cpu.brand_string"
            return subprocess.check_output(command, shell=True).decode().strip()
        elif platform.system() == "Linux":
            command = "cat /proc/cpuinfo"
            all_info = subprocess.check_output(command, shell=True).decode().strip()
            for line in all_info.split("\n"):
                if "model name" in line:
                    return re.sub(".*model name.*:", "", line, count=1).strip()
    except Exception:
        log.error(traceback.format_exc())
    return "cpu"  # fallback if the processor name could not be determined