"""Launcher for the Kohya_ss Gradio GUI.

Builds the tabbed training interface in ``UI`` and, when executed as a
script, parses the command-line options and starts the Gradio server.
"""

import argparse
import os

import gradio as gr

from kohya_gui.class_gui_config import KohyaSSGUIConfig
from kohya_gui.dreambooth_gui import dreambooth_tab
from kohya_gui.finetune_gui import finetune_tab
from kohya_gui.textual_inversion_gui import ti_tab
from kohya_gui.utilities import utilities_tab
from kohya_gui.lora_gui import lora_tab
from kohya_gui.class_lora_tab import LoRATools
from kohya_gui.custom_logging import setup_logging
from kohya_gui.localization_ext import add_javascript

# Logger used by UI(). The script entry point binds it with the requested
# debug level; UI() creates it on demand when the module is imported instead.
log = None


def _read_text_file(path, default=""):
    """Return the UTF-8 text stored at *path*, or *default* if the file is missing."""
    try:
        with open(path, "r", encoding="utf8") as file:
            return file.read()
    except FileNotFoundError:
        return default


def UI(**kwargs):
    """Build the Kohya_ss Gradio interface and launch it.

    Keyword Args:
        language: Forwarded to ``add_javascript``.
        headless: Forwarded to every tab builder. Defaults to ``False``.
        config: Path handed to ``KohyaSSGUIConfig`` as ``config_file_path``.
        do_not_use_shell: When truthy, forces ``use_shell_flag`` to ``False``
            regardless of the ``settings.use_shell`` config entry.
        username, password: When both are non-empty they are passed to
            ``launch`` as the ``auth`` tuple.
        server_port: Passed to ``launch`` only when greater than zero.
        inbrowser: Passed to ``launch`` when truthy.
        share, do_not_share: ``do_not_share`` forces ``share=False``;
            otherwise ``share`` is passed through when truthy.
        listen: Passed to ``launch`` as ``server_name``.
        root_path: Passed to ``launch`` when non-empty.
        debug: Only consulted when the module-level logger has not been
            initialised yet.

    Any other keyword arguments are accepted and ignored.
    """
    global log
    if log is None:
        # UI() was reached without going through the script entry point.
        log = setup_logging(debug=kwargs.get("debug", False))

    add_javascript(kwargs.get("language"))

    headless = kwargs.get("headless", False)
    log.info(f"headless: {headless}")

    css = ""
    stylesheet = _read_text_file("./assets/style.css", default=None)
    if stylesheet is not None:
        log.debug("Load CSS...")
        css += stylesheet + "\n"

    # Default to empty text so a missing file no longer leaves these names
    # unbound when they are used further down.
    release = _read_text_file("./.release")
    readme = _read_text_file("./README.md")

    interface = gr.Blocks(
        css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
    )

    config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))

    if config.is_config_loaded():
        log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")

    # shell=True is the default; the config file may override it and the
    # command-line switch has the final say.
    use_shell_flag = config.get("settings.use_shell", True)
    if kwargs.get("do_not_use_shell", False):
        use_shell_flag = False

    if use_shell_flag:
        log.info("Using shell=True when running external commands...")

    with interface:
        with gr.Tab("Dreambooth"):
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab(
                headless=headless, config=config, use_shell_flag=use_shell_flag
            )
        with gr.Tab("LoRA"):
            lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
        with gr.Tab("Textual Inversion"):
            ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
        with gr.Tab("Finetuning"):
            finetune_tab(
                headless=headless, config=config, use_shell_flag=use_shell_flag
            )
        with gr.Tab("Utilities"):
            # The utilities tab reuses the folder inputs created by the
            # Dreambooth tab.
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                headless=headless,
                config=config,
            )
            # NOTE(review): nesting of this tab under "Utilities" (and of
            # "README" under "About") is reconstructed from statement order;
            # confirm against the intended layout.
            with gr.Tab("LoRA"):
                _ = LoRATools(headless=headless)
        with gr.Tab("About"):
            gr.Markdown(f"kohya_ss GUI release {release}")
            with gr.Tab("README"):
                gr.Markdown(readme)

        html_str = f"\n{release}\n"
        gr.HTML(html_str)

    # Show the interface
    username = kwargs.get("username")
    password = kwargs.get("password")
    # Treat a missing or None port as "let Gradio pick".
    server_port = kwargs.get("server_port") or 0
    root_path = kwargs.get("root_path", None)

    launch_kwargs = {"server_name": kwargs.get("listen")}
    if username and password:
        launch_kwargs["auth"] = (username, password)
    if server_port > 0:
        launch_kwargs["server_port"] = server_port
    if kwargs.get("inbrowser", False):
        launch_kwargs["inbrowser"] = True
    if kwargs.get("do_not_share", False):
        launch_kwargs["share"] = False
    elif kwargs.get("share", False):
        launch_kwargs["share"] = True
    if root_path:
        launch_kwargs["root_path"] = root_path
    launch_kwargs["debug"] = True
    interface.launch(**launch_kwargs)


if __name__ == "__main__":
    # torch.cuda.set_per_process_memory_fraction(0.48)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="./config.toml",
        help="Path to the toml config file for interface defaults",
    )
    parser.add_argument("--debug", action="store_true", help="Debug on")
    parser.add_argument(
        "--listen",
        type=str,
        default="127.0.0.1",
        help="IP to listen on for connections to Gradio",
    )
    parser.add_argument(
        "--username", type=str, default="", help="Username for authentication"
    )
    parser.add_argument(
        "--password", type=str, default="", help="Password for authentication"
    )
    parser.add_argument(
        "--server_port",
        type=int,
        default=0,
        help="Port to run the server listener on",
    )
    parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
    parser.add_argument("--share", action="store_true", help="Share the gradio UI")
    parser.add_argument(
        "--headless", action="store_true", help="Is the server headless"
    )
    parser.add_argument(
        "--language", type=str, default=None, help="Set custom language"
    )
    parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
    parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")
    parser.add_argument(
        "--do_not_use_shell",
        action="store_true",
        help="Enforce not to use shell=True when running external commands",
    )
    parser.add_argument(
        "--do_not_share", action="store_true", help="Do not share the gradio UI"
    )
    parser.add_argument(
        "--root_path",
        type=str,
        default=None,
        help="`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss",
    )

    args = parser.parse_args()

    # Set up logging
    log = setup_logging(debug=args.debug)

    UI(**vars(args))