my_gradio / gradio /ipython_ext.py
xray918's picture
Upload folder using huggingface_hub
0ad74ed verified
try:
from IPython.core.magic import (
needs_local_scope,
register_cell_magic,
)
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
except ImportError:
pass
import gradio as gr
from gradio.routes import App
from gradio.utils import BaseReloader
class CellIdTracker:
    """Determines the most recently run cell in the notebook.

    Needed to keep track of which demo the user is updating: consumers read
    ``current_cell`` to key per-cell state.
    """

    def __init__(self, ipython):
        """Hook into *ipython* so the tracker is told about every cell run.

        Args:
            ipython: The active IPython shell. Its ``events`` registry is used
                to register a ``pre_run_cell`` callback.
        """
        ipython.events.register("pre_run_cell", self.pre_run_cell)
        self.shell = ipython
        # Id of the cell currently executing; "" until the first cell runs.
        self.current_cell: str = ""

    def pre_run_cell(self, info):
        """Record the id of the cell that is about to execute.

        Args:
            info: The object IPython passes to ``pre_run_cell`` callbacks;
                its ``cell_id`` attribute is read here.
        """
        # Must update the same public attribute that ``__init__`` sets and
        # that callers read; writing a different (private) name leaves
        # ``current_cell`` stuck at "" and makes every cell share one key.
        # ``cell_id`` can be absent or None depending on the IPython
        # version / frontend, so fall back to "" to keep the value a str.
        self.current_cell = getattr(info, "cell_id", None) or ""
class JupyterReloader(BaseReloader):
    """Swap a running blocks class in a notebook with the latest cell contents.

    Running demos are stored per notebook cell, keyed by the cell id that a
    ``CellIdTracker`` reports for the currently executing cell.
    """

    def __init__(self, ipython) -> None:
        super().__init__()
        self._cell_tracker = CellIdTracker(ipython)
        # cell id -> demo launched from that cell
        self._running: dict[str, gr.Blocks] = {}

    @property
    def current_cell(self):
        """Id of the notebook cell that is currently executing."""
        tracker = self._cell_tracker
        return tracker.current_cell

    @property
    def running_app(self) -> App:
        """App object served by the current cell's demo.

        Raises:
            KeyError: if no demo is tracked for the current cell.
            RuntimeError: if the tracked demo has no server.
        """
        server = self.running_demo.server
        if server:
            return server.running_app  # type: ignore
        raise RuntimeError("Server not running")

    @property
    def running_demo(self):
        """Demo tracked for the current cell (raises KeyError if none)."""
        cell_id = self.current_cell
        return self._running[cell_id]

    def demo_tracked(self) -> bool:
        """Return True if the current cell already has a tracked demo."""
        cell_id = self.current_cell
        return cell_id in self._running

    def track(self, demo: gr.Blocks):
        """Record *demo* as the running demo for the current cell."""
        cell_id = self.current_cell
        self._running[cell_id] = demo
def load_ipython_extension(ipython):
    """Register the ``%%blocks`` cell magic on *ipython*.

    This is the IPython extension entry point. The registered magic executes
    the cell, then looks up the Blocks instance named by ``--demo-name``.
    On the first run of a cell it launches that demo and starts tracking it;
    on later runs of the same cell it swaps the new Blocks into the running
    demo via the reloader.

    Args:
        ipython: The active IPython shell.
    """
    reloader = JupyterReloader(ipython)

    @magic_arguments()  # type: ignore
    @argument("--demo-name", default="demo", help="Name of gradio blocks instance.")  # type: ignore
    @argument(  # type: ignore
        "--share",
        default=False,
        const=True,
        nargs="?",
        help="Whether to launch with sharing. Will slow down reloading.",
    )
    @register_cell_magic  # type: ignore
    @needs_local_scope  # type: ignore
    def blocks(line, cell, local_ns):
        """Launch a demo defined in a cell in reload mode."""
        args = parse_argstring(blocks, line)  # type: ignore

        # Use the user's namespace as the cell's globals. With ``None`` the
        # globals would be *this module's* namespace, so functions defined in
        # the cell (e.g. event handlers) would resolve free names here instead
        # of in the notebook and fail with NameError on notebook variables or
        # on imports made in the cell itself.
        exec(cell, ipython.user_ns, local_ns)

        demo: gr.Blocks = local_ns[args.demo_name]
        if not reloader.demo_tracked():
            # First run of this cell: launch the demo and remember it.
            demo.launch(share=args.share)
            reloader.track(demo)
        else:
            # Re-run of a tracked cell: swap in the freshly built Blocks.
            reloader.swap_blocks(demo)
        # The value returned here becomes the cell magic's result.
        return reloader.running_demo.artifact