# Spaces: Running on Zero
import logging | |
import os | |
import sys | |
from modules.ffmpeg_env import setup_ffmpeg_path | |
# Start-up bootstrap: call setup_ffmpeg_path() and configure the root logger
# before the remaining application modules are imported. Both steps are
# best-effort, but each is guarded on its own so that a failure in the ffmpeg
# step cannot prevent logging from being configured, and failures are reported
# instead of being silently discarded.
try:
    setup_ffmpeg_path()
except Exception:
    # Only ``Exception`` is caught so KeyboardInterrupt / SystemExit propagate.
    # Logging is not configured yet; a logger without handlers still emits
    # warnings to stderr through logging's last-resort handler.
    logging.getLogger(__name__).warning("setup_ffmpeg_path() failed", exc_info=True)

try:
    # NOTE: loggers are initialized inside the modules, so this config must
    # come first.
    logging.basicConfig(
        level=os.getenv("LOG_LEVEL", "INFO"),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
except Exception:
    # e.g. an unrecognised LOG_LEVEL value makes basicConfig raise ValueError.
    logging.getLogger(__name__).warning("logging.basicConfig() failed", exc_info=True)
import argparse | |
from modules import config | |
from modules.api.api_setup import process_api_args, setup_api_args | |
from modules.api.app_config import app_description, app_title, app_version | |
from modules.gradio_dcls_fix import dcls_patch | |
from modules.models_setup import process_model_args, setup_model_args | |
from modules.utils.env import get_and_update_env | |
from modules.utils.ignore_warn import ignore_useless_warnings | |
from modules.utils.torch_opt import configure_torch_optimizations | |
from modules.webui import webui_config | |
from modules.webui.app import create_interface, webui_init | |
import subprocess

# Best-effort install of flash-attn at start-up (the return code is not
# checked, so a failed install does not stop the application).
#
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is layered on top of the current
# environment instead of replacing it, so PATH, HOME, proxy and virtualenv
# settings remain visible to pip. pip is invoked as ``sys.executable -m pip``
# with an argument list, which targets the interpreter running this script and
# avoids passing a command string through a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

dcls_patch()
ignore_useless_warnings()
def setup_webui_args(parser: argparse.ArgumentParser):
    """Register the web UI command-line options on ``parser``.

    Value options (``type=...``) default to ``None`` and flag options
    (``action="store_true"``) default to ``False``, as provided by argparse.
    Options are added in the order listed below.

    Args:
        parser: The argument parser that receives the options.
    """
    # (flag, keyword arguments for ``add_argument``), in registration order.
    option_table = (
        ("--server_name", {"type": str, "help": "server name"}),
        ("--server_port", {"type": int, "help": "server port"}),
        (
            "--share",
            {"action": "store_true", "help": "share the gradio interface"},
        ),
        ("--debug", {"action": "store_true", "help": "enable debug mode"}),
        (
            "--auth",
            {"type": str, "help": "username:password for authentication"},
        ),
        ("--tts_max_len", {"type": int, "help": "Max length of text for TTS"}),
        ("--ssml_max_len", {"type": int, "help": "Max length of text for SSML"}),
        ("--max_batch_size", {"type": int, "help": "Max batch size for TTS"}),
        # webui_Experimental
        (
            "--webui_experimental",
            {"action": "store_true", "help": "Enable webui_experimental features"},
        ),
        (
            "--language",
            {"type": str, "help": "Set the default language for the webui"},
        ),
        (
            "--api",
            {
                "action": "store_true",
                "help": "use api=True to launch the API together with the webui (run launch.py for only API server)",
            },
        ),
    )

    for flag, add_argument_kwargs in option_table:
        parser.add_argument(flag, **add_argument_kwargs)
def process_webui_args(args):
    """Resolve the web UI settings, launch the Gradio interface and block.

    Settings are read through ``get_and_update_env`` (with the defaults given
    below), stored on ``webui_config`` where applicable, and then used to
    launch the interface created by ``create_interface()``. When the ``api``
    setting is enabled, ``process_api_args`` is called with the launched app.
    The function ends in ``demo.block_thread()``.

    Args:
        args: Parsed command-line namespace; it is forwarded to
            ``get_and_update_env`` and, when the API is enabled, to
            ``process_api_args``.
    """
    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)
    # The resolved language is not used in this function; the call is kept
    # because get_and_update_env presumably records the value elsewhere
    # -- TODO confirm against modules.utils.env.
    get_and_update_env(args, "language", "zh-CN", str)
    api = get_and_update_env(args, "api", False, bool)

    webui_config.experimental = get_and_update_env(
        args, "webui_experimental", False, bool
    )
    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)

    configure_torch_optimizations()
    webui_init()
    demo = create_interface()

    if auth:
        # Split on the first ":" only, so a password may itself contain ":"
        # and the result is still a (username, password) pair.
        auth = tuple(auth.split(":", 1))

    # The documentation routes are exposed only when the API is enabled and
    # docs have not been turned off via config.runtime_env_vars.no_docs.
    expose_docs = api is not False and not config.runtime_env_vars.no_docs

    app, _local_url, _share_url = demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
        # Do not block here: the app object is still needed below.
        prevent_thread_lock=True,
        inbrowser=sys.platform == "win32",
        app_kwargs={
            "title": app_title,
            "description": app_description,
            "version": app_version,
            "redoc_url": "/redoc" if expose_docs else None,
            "docs_url": "/docs" if expose_docs else None,
        },
    )

    # gradio uses a very open CORS policy via app.user_middleware, which makes
    # it possible for an attacker to trick the user into opening a malicious
    # HTML page that sends requests to the running web UI on the user's behalf.
    # We disable this here. Suggested by RyotaK.
    app.user_middleware = [
        x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
    ]

    if api:
        process_api_args(args, app)

    demo.block_thread()
if __name__ == "__main__":
    from dotenv import load_dotenv

    # Load environment variables from the file named by ENV_FILE
    # (".env.webui" when ENV_FILE is not set).
    env_file = os.getenv("ENV_FILE", ".env.webui")
    load_dotenv(dotenv_path=env_file)

    # Build one parser carrying the web UI, model and API options.
    parser = argparse.ArgumentParser(description="Gradio App")
    for register_options in (setup_webui_args, setup_model_args, setup_api_args):
        register_options(parser)
    args = parser.parse_args()

    # Model arguments are processed first, then the web UI is started.
    process_model_args(args)
    process_webui_args(args)