|
""" |
|
This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import inspect |
|
from typing import AsyncGenerator, Callable, Literal, Union, cast |
|
|
|
import anyio |
|
from gradio_client.documentation import document |
|
|
|
from gradio.blocks import Blocks |
|
from gradio.components import ( |
|
Button, |
|
Chatbot, |
|
Component, |
|
Markdown, |
|
MultimodalTextbox, |
|
State, |
|
Textbox, |
|
get_component_instance, |
|
Dataset, |
|
) |
|
from gradio.events import Dependency, on |
|
from gradio.helpers import special_args |
|
from gradio.layouts import Accordion, Group, Row |
|
from gradio.routes import Request |
|
from gradio.themes import ThemeClass as Theme |
|
from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda |
|
|
|
|
|
@document()
class ChatInterface(Blocks):
    """
    ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
    a web-based demo around a chatbot model in a few lines of code. This variant requires four parameters:
    fn, which governs the response of the chatbot based on the user input and chat history; pre_fn and
    post_fn, which are chained into every submit, retry, undo, and clear event; and chatbot, a gr.Chatbot
    that the interface renders. Additional parameters can be used to control the appearance and behavior
    of the demo.

    When fn is a (sync or async) generator, every yielded item must be a (response, interrupter) pair.
    The interrupter (a callable or None) is stored in session state and called when the stop button is
    clicked.

    Example:
        import gradio as gr

        def echo(message, history):
            return message

        demo = ChatInterface(
            fn=echo,
            post_fn=lambda: None,
            pre_fn=lambda: None,
            chatbot=gr.Chatbot(render=False),
            title="Echo Bot",
        )
        demo.launch()
    Demos: chatinterface_multimodal, chatinterface_random_response, chatinterface_streaming_echo
    Guides: creating-a-chatbot-fast, sharing-your-app
    """

    def __init__(
        self,
        fn: Callable,
        post_fn: Callable,
        pre_fn: Callable,
        chatbot: Chatbot,
        *,
        show_stop_button: bool = True,
        post_fn_kwargs: dict | None = None,
        pre_fn_kwargs: dict | None = None,
        multimodal: bool = False,
        textbox: Textbox | MultimodalTextbox | None = None,
        additional_inputs: str | Component | list[str | Component] | None = None,
        additional_inputs_accordion_name: str | None = None,
        additional_inputs_accordion: str | Accordion | None = None,
        examples: Dataset | None = None,
        title: str | None = None,
        description: str | None = None,
        theme: Theme | str | None = None,
        css: str | None = None,
        js: str | None = None,
        head: str | None = None,
        analytics_enabled: bool | None = None,
        submit_btn: str | None | Button = "Submit",
        stop_btn: str | None | Button = "Stop",
        retry_btn: str | None | Button = "🔄 Retry",
        undo_btn: str | None | Button = "↩️ Undo",
        clear_btn: str | None | Button = "🗑️ Clear",
        autofocus: bool = True,
        concurrency_limit: int | None | Literal["default"] = "default",
        fill_height: bool = True,
        delete_cache: tuple[int, int] | None = None,
    ):
        """
        Parameters:
            fn: function producing the reply. Called with the user message, the chat history without the pending message, and the values of any additional inputs (a gr.Request is injected if fn declares one). A plain function returns the response; a generator must yield (response, interrupter) pairs.
            post_fn: callable chained as the last step of every submit, retry, undo, and clear event.
            pre_fn: callable chained right after the first step of every submit, retry, undo, and clear event.
            chatbot: gr.Chatbot rendered inside this interface via `chatbot.render()` (presumably created with render=False).
            show_stop_button: visibility given to the stop button while a streaming response is running.
            post_fn_kwargs: extra keyword arguments forwarded to the `.then(...)` call that registers post_fn. Defaults to {}.
            pre_fn_kwargs: extra keyword arguments forwarded to the `.then(...)` call that registers pre_fn. Defaults to {}.
            multimodal: if True, a gr.MultimodalTextbox is used and no separate submit button is shown.
            textbox: optional pre-built gr.Textbox / gr.MultimodalTextbox to render instead of the default one.
            additional_inputs: extra input components (or their string shortcuts) whose values are passed to fn after the history.
            additional_inputs_accordion_name: deprecated; accordion label, used only when additional_inputs_accordion is None.
            additional_inputs_accordion: label string or gr.Accordion used to wrap unrendered additional inputs.
            examples: optional gr.Dataset; clicking a sample copies its first value into the textbox.
            title: heading shown at the top and passed to gr.Blocks (falls back to "Gradio").
            description: markdown shown under the title.
            theme, css, js, head, analytics_enabled, fill_height, delete_cache: forwarded to gr.Blocks.
            submit_btn: label string, gr.Button (rendered here), or None. Ignored when multimodal.
            stop_btn: label string, gr.Button (rendered here, initially hidden), or None.
            retry_btn, undo_btn, clear_btn: label string (a button is created in a row below the input), gr.Button (used as given, not rendered here), or None.
            autofocus: whether the default textbox is autofocused.
            concurrency_limit: concurrency limit applied to the fn and post_fn events.
        """
        super().__init__(
            analytics_enabled=analytics_enabled,
            mode="chat_interface",
            css=css,
            title=title or "Gradio",
            theme=theme,
            js=js,
            head=head,
            fill_height=fill_height,
            delete_cache=delete_cache,
        )

        # Both kwargs dicts are splatted with `**` into `.then(...)`, which
        # requires a mapping -- so the defaults must be dicts, never lists/None.
        self.post_fn = post_fn
        self.post_fn_kwargs = {} if post_fn_kwargs is None else post_fn_kwargs
        self.pre_fn = pre_fn
        self.pre_fn_kwargs = {} if pre_fn_kwargs is None else pre_fn_kwargs

        self.show_stop_button = show_stop_button
        self.multimodal = multimodal
        self.concurrency_limit = concurrency_limit
        self.fn = fn
        self.is_async = inspect.iscoroutinefunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)
        self.is_generator = inspect.isgeneratorfunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)

        if additional_inputs:
            if not isinstance(additional_inputs, list):
                additional_inputs = [additional_inputs]
            self.additional_inputs = [
                get_component_instance(i) for i in additional_inputs
            ]
        else:
            self.additional_inputs = []
        self.additional_inputs_accordion_params = self._resolve_accordion_params(
            additional_inputs_accordion, additional_inputs_accordion_name
        )

        with self:
            if title:
                Markdown(
                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
                )
            if description:
                Markdown(description)

            self.chatbot = chatbot.render()

            with Group():
                with Row():
                    if textbox:
                        if self.multimodal:
                            submit_btn = None
                        else:
                            textbox.container = False
                        textbox.show_label = False
                        textbox_ = textbox.render()
                        if not isinstance(textbox_, (Textbox, MultimodalTextbox)):
                            raise TypeError(
                                f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {type(textbox_)}"
                            )
                        self.textbox = textbox_
                    elif self.multimodal:
                        submit_btn = None
                        self.textbox = MultimodalTextbox(
                            show_label=False,
                            label="Message",
                            placeholder="Type a message...",
                            scale=7,
                            autofocus=autofocus,
                        )
                    else:
                        self.textbox = Textbox(
                            container=False,
                            show_label=False,
                            label="Message",
                            placeholder="Type a message...",
                            scale=7,
                            autofocus=autofocus,
                        )

                    if submit_btn is not None and not multimodal:
                        if isinstance(submit_btn, Button):
                            submit_btn.render()
                        elif isinstance(submit_btn, str):
                            submit_btn = Button(
                                submit_btn,
                                variant="primary",
                                scale=1,
                                min_width=150,
                            )
                        else:
                            raise ValueError(
                                f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
                            )

                    if stop_btn is not None:
                        if isinstance(stop_btn, Button):
                            stop_btn.visible = False
                            stop_btn.render()
                        elif isinstance(stop_btn, str):
                            stop_btn = Button(
                                stop_btn,
                                variant="stop",
                                visible=False,
                                scale=1,
                                min_width=150,
                            )
                        else:
                            raise ValueError(
                                f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
                            )

            # String labels must become real Buttons: `_setup_events` calls
            # `.click` on these, which a plain str does not have.
            retry_btn, undo_btn, clear_btn = self._build_secondary_buttons(
                [
                    ("retry_btn", retry_btn),
                    ("undo_btn", undo_btn),
                    ("clear_btn", clear_btn),
                ]
            )
            self.buttons = [retry_btn, undo_btn, clear_btn, submit_btn, stop_btn]

            self.fake_api_btn = Button("Fake API", visible=False)
            self.fake_response_textbox = Textbox(label="Response", visible=False)
            (
                self.retry_btn,
                self.undo_btn,
                self.clear_btn,
                self.submit_btn,
                self.stop_btn,
            ) = self.buttons

            any_unrendered_inputs = any(
                not inp.is_rendered for inp in self.additional_inputs
            )
            if self.additional_inputs and any_unrendered_inputs:
                with Accordion(**self.additional_inputs_accordion_params):
                    for input_component in self.additional_inputs:
                        if not input_component.is_rendered:
                            input_component.render()

            self.saved_input = State()
            self.chatbot_state = (
                State(self.chatbot.value) if self.chatbot.value else State([])
            )
            # Holds the interrupter yielded by a streaming fn; the stop button
            # calls it. Created inside `with self:` so it is rendered into
            # these Blocks like the other State components.
            self.interrupter = State(None)

            self._setup_events()
            self._setup_api()

            # Event listeners must be registered inside a Blocks context.
            if examples is not None:
                examples.click(
                    lambda x: x[0],
                    inputs=[examples],
                    outputs=self.textbox,
                    show_progress=False,
                    queue=False,
                )

    @staticmethod
    def _resolve_accordion_params(
        accordion: str | Accordion | None, legacy_name: str | None
    ) -> dict:
        """Return the kwargs used to build the additional-inputs Accordion.

        `legacy_name` (the deprecated `additional_inputs_accordion_name`) is
        honoured only when no `accordion` is given; previously it was always
        overwritten by the default label.

        Raises:
            ValueError: if `accordion` is neither None, a str, nor a gr.Accordion.
        """
        if legacy_name is not None:
            print(
                "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
            )
        if accordion is None:
            if legacy_name is not None:
                return {"label": legacy_name}
            return {"label": "Additional Inputs", "open": False}
        if isinstance(accordion, str):
            return {"label": accordion}
        if isinstance(accordion, Accordion):
            return accordion.recover_kwargs(accordion.get_config())
        raise ValueError(
            f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(accordion)}"
        )

    @staticmethod
    def _build_secondary_buttons(
        named_buttons: list[tuple[str, str | Button | None]],
    ) -> list[Button | None]:
        """Turn the retry/undo/clear arguments into Button instances (or None).

        String labels become small secondary Buttons rendered together in one
        Row; Button instances are returned as given (not rendered here). Must
        be called inside the Blocks context.

        Raises:
            ValueError: if an argument is not a gr.Button, string, or None.
        """
        for param_name, btn in named_buttons:
            if btn is not None and not isinstance(btn, (str, Button)):
                raise ValueError(
                    f"The {param_name} parameter must be a gr.Button, string, or None, not {type(btn)}"
                )
        if not any(isinstance(btn, str) for _, btn in named_buttons):
            # Nothing to create: avoid rendering an empty Row.
            return [btn for _, btn in named_buttons]
        with Row():
            return [
                Button(btn, variant="secondary", size="sm", min_width=60)
                if isinstance(btn, str)
                else btn
                for _, btn in named_buttons
            ]

    def _setup_events(self) -> None:
        """Wire the submit, retry, undo, and clear event chains."""
        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
        # `_stream_fn` yields (chatbot, state, interrupter) while `_submit_fn`
        # returns only (chatbot, state); the output list must match.
        submit_outputs = [self.chatbot, self.chatbot_state]
        if self.is_generator:
            submit_outputs.append(self.interrupter)
        concurrency_limit = cast(
            Union[int, Literal["default"], None], self.concurrency_limit
        )

        submit_triggers = (
            [self.textbox.submit, self.submit_btn.click]
            if self.submit_btn
            else [self.textbox.submit]
        )
        # NOTE(review): the chain ends with `.then(post_fn)`, so the stop
        # button's `cancels=` targets the post_fn dependency rather than
        # submit_fn; stopping generation relies on the interrupter. Confirm
        # this is intended.
        submit_event = (
            on(
                submit_triggers,
                self._clear_and_save_textbox,
                [self.textbox],
                [self.textbox, self.saved_input],
                show_api=False,
                queue=False,
            )
            .then(
                self.pre_fn,
                **self.pre_fn_kwargs,
                show_api=False,
                queue=False,
            )
            .then(
                self._display_input,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.chatbot_state],
                show_api=False,
                queue=False,
            )
            .then(
                submit_fn,
                [self.saved_input, self.chatbot_state] + self.additional_inputs,
                submit_outputs,
                show_api=False,
                concurrency_limit=concurrency_limit,
            )
            .then(
                self.post_fn,
                **self.post_fn_kwargs,
                show_api=False,
                concurrency_limit=concurrency_limit,
            )
        )
        self._setup_stop_events(submit_triggers, submit_event)

        if self.retry_btn:
            retry_event = (
                self.retry_btn.click(
                    self._delete_prev_fn,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.saved_input, self.chatbot_state],
                    show_api=False,
                    queue=False,
                )
                .then(
                    self.pre_fn,
                    **self.pre_fn_kwargs,
                    show_api=False,
                    queue=False,
                )
                .then(
                    self._display_input,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.chatbot_state],
                    show_api=False,
                    queue=False,
                )
                .then(
                    submit_fn,
                    [self.saved_input, self.chatbot_state] + self.additional_inputs,
                    submit_outputs,
                    show_api=False,
                    concurrency_limit=concurrency_limit,
                )
                .then(
                    self.post_fn,
                    **self.post_fn_kwargs,
                    show_api=False,
                    concurrency_limit=concurrency_limit,
                )
            )
            self._setup_stop_events([self.retry_btn.click], retry_event)

        if self.undo_btn:
            self.undo_btn.click(
                self._delete_prev_fn,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.saved_input, self.chatbot_state],
                show_api=False,
                queue=False,
            ).then(
                self.pre_fn,
                **self.pre_fn_kwargs,
                show_api=False,
                queue=False,
            ).then(
                # Put the removed message back into the textbox.
                async_lambda(lambda x: x),
                [self.saved_input],
                [self.textbox],
                show_api=False,
                queue=False,
            ).then(
                self.post_fn,
                **self.post_fn_kwargs,
                show_api=False,
                concurrency_limit=concurrency_limit,
            )

        if self.clear_btn:
            self.clear_btn.click(
                async_lambda(lambda: ([], [], None)),
                None,
                [self.chatbot, self.chatbot_state, self.saved_input],
                queue=False,
                show_api=False,
            ).then(
                self.pre_fn,
                **self.pre_fn_kwargs,
                show_api=False,
                queue=False,
            ).then(
                self.post_fn,
                **self.post_fn_kwargs,
                show_api=False,
                concurrency_limit=concurrency_limit,
            )

    def _setup_stop_events(
        self, event_triggers: list[Callable], event_to_cancel: Dependency
    ) -> None:
        """Toggle submit/stop button visibility around a streaming event and
        make the stop button call the stored interrupter and cancel
        `event_to_cancel`. No-op unless fn is a generator and a stop button
        exists."""

        def perform_interrupt(ipc):
            if ipc is not None:
                ipc()

        if self.stop_btn and self.is_generator:
            if self.submit_btn:
                for event_trigger in event_triggers:
                    event_trigger(
                        async_lambda(
                            lambda: (
                                Button(visible=False),
                                Button(visible=self.show_stop_button),
                            )
                        ),
                        None,
                        [self.submit_btn, self.stop_btn],
                        show_api=False,
                        queue=False,
                    )
                event_to_cancel.then(
                    async_lambda(lambda: (Button(visible=True), Button(visible=False))),
                    None,
                    [self.submit_btn, self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            else:
                for event_trigger in event_triggers:
                    event_trigger(
                        async_lambda(lambda: Button(visible=self.show_stop_button)),
                        None,
                        [self.stop_btn],
                        show_api=False,
                        queue=False,
                    )
                event_to_cancel.then(
                    async_lambda(lambda: Button(visible=False)),
                    None,
                    [self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            self.stop_btn.click(
                fn=perform_interrupt,
                inputs=[self.interrupter],
                cancels=event_to_cancel,
                show_api=False,
            )

    def _setup_api(self) -> None:
        """Expose fn as the `chat` API endpoint via a hidden button."""
        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn

        self.fake_api_btn.click(
            api_fn,
            [self.textbox, self.chatbot_state] + self.additional_inputs,
            [self.textbox, self.chatbot_state],
            api_name="chat",
            concurrency_limit=cast(
                Union[int, Literal["default"], None], self.concurrency_limit
            ),
        )

    def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]:
        """Return an empty textbox value plus the message to stash in state."""
        if self.multimodal:
            return {"text": "", "files": []}, message
        else:
            return "", message

    def _append_multimodal_history(
        self,
        message: dict[str, list],
        response: str | None,
        history: list[list[str | tuple | None]],
    ):
        """Append one row per file, then the text row, to `history` in place."""
        for x in message["files"]:
            history.append([(x,), None])
        if message["text"] is None or not isinstance(message["text"], str):
            return
        elif message["text"] == "" and message["files"] != []:
            history.append([None, response])
        else:
            history.append([message["text"], response])

    def _history_without_pending_input(
        self,
        message: str | dict[str, list],
        history_with_input: list[list[str | tuple | None]],
    ) -> list[list[str | tuple | None]]:
        """Return a copy of the history without the rows `_display_input`
        added for `message`.

        Slices by an explicit end index so that a pending count of 0 keeps
        the whole history (`history[:-0]` would drop everything).
        """
        if self.multimodal and isinstance(message, dict):
            pending = len(message["files"]) + (
                1 if message["text"] is not None else 0
            )
        else:
            pending = 1
        return history_with_input[: max(len(history_with_input) - pending, 0)]

    async def _display_input(
        self, message: str | dict[str, list], history: list[list[str | tuple | None]]
    ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
        """Show the user's message in the chat before the reply arrives."""
        if self.multimodal and isinstance(message, dict):
            self._append_multimodal_history(message, None, history)
        elif isinstance(message, str):
            history.append([message, None])
        return history, history

    async def _submit_fn(
        self,
        message: str | dict[str, list],
        history_with_input: list[list[str | tuple | None]],
        request: Request,
        *args,
    ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
        """Call a non-streaming fn and return the updated history twice
        (for the chatbot and its state)."""
        history = self._history_without_pending_input(message, history_with_input)
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            response = await self.fn(*inputs)
        else:
            response = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )

        if self.multimodal and isinstance(message, dict):
            self._append_multimodal_history(message, response, history)
        elif isinstance(message, str):
            history.append([message, response])
        return history, history

    async def _stream_fn(
        self,
        message: str | dict[str, list],
        history_with_input: list[list[str | tuple | None]],
        request: Request,
        *args,
    ) -> AsyncGenerator:
        """Stream a generator fn, yielding (history, history, interrupter).

        fn must yield (response, interrupter) pairs. Every yield carries three
        values to match the three outputs registered for streaming events.
        """
        history = self._history_without_pending_input(message, history_with_input)
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            generator = self.fn(*inputs)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)

        is_multimodal = self.multimodal and isinstance(message, dict)
        try:
            first_response, first_interrupter = await async_iteration(generator)
        except (StopIteration, StopAsyncIteration):
            # fn yielded nothing (async iterators signal this with
            # StopAsyncIteration): show the user turn with no reply and clear
            # the interrupter.
            if is_multimodal:
                self._append_multimodal_history(message, None, history)
                yield history, history, None
            else:
                update = history + [[message, None]]
                yield update, update, None
            return

        if is_multimodal:
            for x in message["files"]:
                history.append([(x,), None])
            text = message["text"]
        else:
            text = message
        update = history + [[text, first_response]]
        yield update, update, first_interrupter
        async for response, interrupter in generator:
            update = history + [[text, response]]
            yield update, update, interrupter

    async def _api_submit_fn(
        self, message: str, history: list[list[str | None]], request: Request, *args
    ) -> tuple[str, list[list[str | None]]]:
        """API endpoint for a non-streaming fn: return (response, history)."""
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            response = await self.fn(*inputs)
        else:
            response = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
        history.append([message, response])
        return response, history

    async def _api_stream_fn(
        self, message: str, history: list[list[str | None]], request: Request, *args
    ) -> AsyncGenerator:
        """API endpoint for a generator fn: yield (response, history).

        fn yields (response, interrupter) pairs (see `_stream_fn`); only the
        response is exposed through the API.
        """
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            generator = self.fn(*inputs)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)

        try:
            first_response, _ = await async_iteration(generator)
        except (StopIteration, StopAsyncIteration):
            yield None, history + [[message, None]]
            return
        yield first_response, history + [[message, first_response]]
        async for response, _ in generator:
            yield response, history + [[message, response]]

    async def _delete_prev_fn(
        self,
        message: str | dict[str, list],
        history: list[list[str | tuple | None]],
    ) -> tuple[
        list[list[str | tuple | None]],
        str | dict[str, list],
        list[list[str | tuple | None]],
    ]:
        """Remove the last exchange from the history.

        Returns (history, message or "", history) so the chatbot, the saved
        input, and the chat state can all be updated.
        """
        if self.multimodal and isinstance(message, dict):
            history = self._history_without_pending_input(message, history)
        else:
            # Drop rows from the end until a row with a str on both sides
            # has been removed.
            while history:
                deleted_a, deleted_b = history[-1]
                history = history[:-1]
                if isinstance(deleted_a, str) and isinstance(deleted_b, str):
                    break
        return history, message or "", history
|
|