import argparse import torch import logging from library.utils import setup_logging from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, LCMScheduler from library.sdxl_model_util import convert_diffusers_unet_state_dict_to_sdxl, sdxl_original_unet, save_stable_diffusion_checkpoint, _load_state_dict_on_device as load_state_dict_on_device from accelerate import init_empty_weights # Initialize logging setup_logging() logger = logging.getLogger(__name__) def parse_command_line_arguments(): argument_parser = argparse.ArgumentParser("lcm_convert") argument_parser.add_argument("--name", help="Name of the new LCM model", required=True, type=str) argument_parser.add_argument("--model", help="A model to convert", required=True, type=str) argument_parser.add_argument("--lora-scale", default=1.0, help="Strength of the LCM", type=float) argument_parser.add_argument("--sdxl", action="store_true", help="Use SDXL models") argument_parser.add_argument("--ssd-1b", action="store_true", help="Use SSD-1B models") return argument_parser.parse_args() def load_diffusion_pipeline(command_line_args): if command_line_args.sdxl or command_line_args.ssd_1b: return StableDiffusionXLPipeline.from_single_file(command_line_args.model) else: return StableDiffusionPipeline.from_single_file(command_line_args.model) def convert_and_save_diffusion_model(diffusion_pipeline, command_line_args): diffusion_pipeline.scheduler = LCMScheduler.from_config(diffusion_pipeline.scheduler.config) lora_weight_file_path = "latent-consistency/lcm-lora-" + ("sdxl" if command_line_args.sdxl else "ssd-1b" if command_line_args.ssd_1b else "sdv1-5") diffusion_pipeline.load_lora_weights(lora_weight_file_path) diffusion_pipeline.fuse_lora(lora_scale=command_line_args.lora_scale) diffusion_pipeline = diffusion_pipeline.to(dtype=torch.float16) logger.info("Saving file...") text_encoder_primary = diffusion_pipeline.text_encoder text_encoder_secondary = diffusion_pipeline.text_encoder_2 variational_autoencoder = diffusion_pipeline.vae unet_network = diffusion_pipeline.unet del diffusion_pipeline state_dict = convert_diffusers_unet_state_dict_to_sdxl(unet_network.state_dict()) with init_empty_weights(): unet_network = sdxl_original_unet.SdxlUNet2DConditionModel() load_state_dict_on_device(unet_network, state_dict, device="cuda", dtype=torch.float16) save_stable_diffusion_checkpoint( command_line_args.name, text_encoder_primary, text_encoder_secondary, unet_network, None, None, None, variational_autoencoder, None, None, torch.float16, ) logger.info("...done saving") def main(): command_line_args = parse_command_line_arguments() try: diffusion_pipeline = load_diffusion_pipeline(command_line_args) convert_and_save_diffusion_model(diffusion_pipeline, command_line_args) except Exception as error: logger.error(f"An error occurred: {error}") if __name__ == "__main__": main()