import os
import random
import time
from contextlib import nullcontext
from os import path
from sys import platform

import torch

# Keep all Hugging Face downloads in a local "models" directory next to
# this file so the app is self-contained.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

is_mac = platform == "darwin"


def should_use_fp16():
    """Return True if the current device can run fp16 inference reliably."""
    if is_mac:
        return True

    gpu_props = torch.cuda.get_device_properties("cuda")

    # Pre-Pascal GPUs (compute capability < 6) lack usable fp16 support.
    if gpu_props.major < 6:
        return False

    # The GTX 16 series is known for unreliable fp16 output, so fall back
    # to fp32 on those cards.
    nvidia_16_series = ["1660", "1650", "1630"]
    for x in nvidia_16_series:
        if x in gpu_props.name:
            return False

    return True


class timer:
    """Context manager that prints how long the wrapped block took."""

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {round(end - self.start, 2)}s")


def load_models(model_id="stabilityai/stable-diffusion-xl-base-1.0"):
    """Build an SDXL image-to-image pipeline with the LCM UNet and a logo
    LoRA, and return an `infer` closure for running generations."""
    from diffusers import AutoPipelineForImage2Image, LCMScheduler, UNet2DConditionModel
    from diffusers.utils import load_image

    if not is_mac:
        # TF32 speeds up fp32 matmuls on Ampere+ GPUs at negligible cost.
        torch.backends.cuda.matmul.allow_tf32 = True

    use_fp16 = should_use_fp16()

    lora_id = "artificialguybr/LogoRedmond-LogoLoraForSDXL-V2"

    # The LCM UNet replaces the base SDXL UNet so the pipeline can produce
    # usable images in very few steps. Its dtype must match the pipeline's,
    # so only request the fp16 variant when fp16 is actually in use.
    if use_fp16:
        unet = UNet2DConditionModel.from_pretrained(
            "latent-consistency/lcm-sdxl",
            cache_dir=cache_path,
            torch_dtype=torch.float16,
            variant="fp16",
        )
        pipe = AutoPipelineForImage2Image.from_pretrained(
            model_id,
            unet=unet,
            cache_dir=cache_path,
            torch_dtype=torch.float16,
            variant="fp16",
            safety_checker=None,
        )
    else:
        unet = UNet2DConditionModel.from_pretrained(
            "latent-consistency/lcm-sdxl", cache_dir=cache_path
        )
        pipe = AutoPipelineForImage2Image.from_pretrained(
            model_id, unet=unet, cache_dir=cache_path, safety_checker=None
        )

    # LCM needs its own scheduler; then bake the LoRA into the weights so
    # there is no per-step LoRA overhead.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(lora_id)
    pipe.fuse_lora()

    device = "mps" if is_mac else "cuda"
    pipe.to(device=device)

    generator = torch.Generator()

    def infer(
        prompt,
        image,
        num_inference_steps=4,
        guidance_scale=1,
        strength=0.9,
        seed=None,
    ):
        # A `random.randrange(...)` default argument would be evaluated once
        # at definition time and reuse the same seed on every call, so draw
        # a fresh seed per call instead.
        if seed is None:
            seed = random.randrange(0, 2**63)
        with torch.inference_mode():
            # autocast is CUDA-only; on mps fall back to a no-op context.
            with torch.autocast("cuda") if device == "cuda" else nullcontext():
                with timer("inference"):
                    return pipe(
                        prompt=prompt,
                        image=load_image(image),
                        generator=generator.manual_seed(seed),
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                    ).images[0]

    return infer
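

# --- Usage sketch (not part of the original module) ---
# Illustrates how the `infer` closure returned by load_models() is meant to
# be called. The prompt text and file paths below are placeholder
# assumptions; any local path or URL accepted by diffusers' `load_image`
# works for the `image` argument.
if __name__ == "__main__":
    infer = load_models()
    result = infer(
        "a minimal flat-design fox logo, vector style",  # placeholder prompt
        "sketch.png",  # placeholder input image path
        num_inference_steps=4,  # LCM yields usable results in a few steps
        guidance_scale=1,  # LCM is typically run with little or no CFG
        strength=0.9,  # how far to move away from the input image
    )
    result.save("logo_out.png")  # the pipeline returns a PIL.Image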