# law-bot / app.py
import os
import gradio as gr
from kiwipiepy import Kiwi
from typing import List, Tuple, Generator
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder
from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt
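
# Disable HuggingFace tokenizers parallelism to avoid fork-related warnings in worker processes.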
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Instantiate the Korean morphological analyzer once; constructing Kiwi on every call is costly.
kiwi = Kiwi()


def kiwi_tokenize(text):
    """Tokenize Korean text into morpheme surface forms."""
    return [token.form for token in kiwi.tokenize(text)]
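

# Build the embedding model and document retriever once at startup; both are reused across requests.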
embeddings = get_embeddings()
retriever = load_retrievers(embeddings)
# Available models (key: model identifier, value: label shown to the user)
AVAILABLE_MODELS = {
    # "gpt_3_5_turbo": "GPT-3.5 Turbo",
    "gpt_4o": "GPT-4o",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
    # "llama3_70b": "Llama3 70b",
}
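

# RAG pipeline: retrieve documents, reorder them with LongContextReorder (which places the
# most relevant documents at the start and end of the context to mitigate the "lost in the
# middle" effect), then fill the prompt and run the selected LLM.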
def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
    """Build a RAG chain for the given model, with prior turns baked into the prompt."""
    langchain_messages = []
    for human, ai in chat_history:
        langchain_messages.append(HumanMessage(content=human))
        langchain_messages.append(AIMessage(content=ai))
    # Select the concrete LLM at runtime via the configurable "llm" field.
    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt().partial(history=langchain_messages)
    return (
        {
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )


def respond_stream(
    message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
    """Stream the answer chunk by chunk."""
    rag_chain = create_rag_chain(history, model)
    for chunk in rag_chain.stream(message):
        yield chunk


def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
    """Return the complete answer in a single call (non-streaming)."""
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


def get_model_key(label: str) -> str:
    """Map a user-facing model label back to its model identifier."""
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def validate_input(message: str) -> bool:
    """Check that the input message is non-empty after stripping whitespace."""
    return bool(message.strip())


def chat_function(
    message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
    """Gradio ChatInterface callback; dispatches to streaming or blocking response."""
    if not validate_input(message):
        yield "λ©”μ‹œμ§€λ₯Ό μž…λ ₯ν•΄μ£Όμ„Έμš”."  # "Please enter a message."
        return
    model_key = get_model_key(model_label)
    if STREAMING:
        response = ""
        # Gradio expects the accumulated text so far on each yield, not incremental deltas.
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        response = respond(message, history, model_key)
        yield response
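

# Gradio UI: a model selector plus a chat interface wired to chat_function.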
with gr.Blocks(
    fill_height=True,
) as demo:
    gr.Markdown("# λŒ€λ²•μ› νŒλ‘€ 상담 λ„μš°λ―Έ")  # "Supreme Court Case Law Assistant"
    gr.Markdown(
        "μ•ˆλ…•ν•˜μ„Έμš”! λŒ€λ²•μ› νŒλ‘€μ— κ΄€ν•œ μ§ˆλ¬Έμ— λ‹΅λ³€ν•΄λ“œλ¦¬λŠ” AI 상담 λ„μš°λ―Έμž…λ‹ˆλ‹€. νŒλ‘€ 검색, 해석, 적용 등에 λŒ€ν•΄ κΆκΈˆν•˜μ‹  점이 있으면 μ–Έμ œλ“  λ¬Όμ–΄λ³΄μ„Έμš”."
    )
    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="λͺ¨λΈ 선택",  # "Select model"
        value=list(AVAILABLE_MODELS.values())[0],
    )
    chatbot = gr.ChatInterface(
        fn=chat_function,
        autofocus=True,
        fill_height=True,
        multimodal=False,
        examples=[
            [
                "쀑고차 거래λ₯Ό ν–ˆλŠ”λ° λΆˆλŸ‰μœΌλ‘œ μ°¨ μˆ˜λ¦¬μ— 500λ§Œμ›μ΄ λ“€μ—ˆμŠ΅λ‹ˆλ‹€. νŒλ§€μžμ—κ²Œ 법적 μ±…μž„μ„ 물을 수 μžˆλ‚˜μš”? λΉ„μŠ·ν•œ νŒλ‘€λ₯Ό μ†Œκ°œν•΄μ£Όμ„Έμš”.",
                "GPT-4o",
            ],
            [
                "μ•½ 2천 ν‰μ˜ 농지λ₯Ό κ΅¬λ§€ν–ˆλŠ”λ°, μ•Œκ³  λ³΄λ‹ˆ 주택을 지을 수 μ—†λŠ” λ•…μ΄μ—ˆμŠ΅λ‹ˆλ‹€. 이와 μœ μ‚¬ν•œ 뢀동산 사기 κ΄€λ ¨ νŒλ‘€λ₯Ό μ•Œλ €μ£Όμ„Έμš”.",
                "GPT-4o",
            ],
            [
                "지인이 μž₯λ‚œμœΌλ‘œ νœ˜λ‘λ₯Έ 칼에 νŒ”μ΄ 20cm κ°€λŸ‰ μ°”λ ΈμŠ΅λ‹ˆλ‹€. μž₯λ‚œμ΄λΌκ³  μ£Όμž₯ν•˜λŠ”λ°, 이와 μœ μ‚¬ν•œ 상해 κ΄€λ ¨ νŒλ‘€λ₯Ό μ•Œλ €μ£Όμ„Έμš”.",
                "GPT-4o",
            ],
        ],
        additional_inputs=[model_dropdown],
    )


if __name__ == "__main__":
    demo.launch()