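"""Gradio chat app for a Supreme Court precedent consultation assistant.

Wires a LangChain RAG pipeline (retriever -> prompt -> LLM -> string output)
into a gr.ChatInterface and lets the user choose among several chat models.
"""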
import os
from typing import Generator, List, Tuple

import gradio as gr
from kiwipiepy import Kiwi
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder

from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt

# Silence the HuggingFace tokenizers fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Load the Kiwi morphological analyzer once at import time; constructing it on
# every call would reload its dictionaries each time.
kiwi = Kiwi()


def kiwi_tokenize(text):
    """Tokenize Korean text into morpheme surface forms using Kiwi."""
    return [token.form for token in kiwi.tokenize(text)]
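# NOTE: kiwi_tokenize is not referenced elsewhere in this module. The assumption
# is that load_retrievers (libs.retrievers) uses it, or an equivalent Kiwi-based
# tokenizer, for a keyword retriever such as BM25; verify against that module.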


# Embedding model and retriever are built once and shared by every request.
embeddings = get_embeddings()
retriever = load_retrievers(embeddings)


# Internal model keys mapped to the display labels shown in the model dropdown.
AVAILABLE_MODELS = {
    "gpt_4o": "GPT-4o",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
}
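# NOTE: these keys are assumed to match the configurable alternatives registered
# in libs.llm.get_llm, so that .with_config(configurable={"llm": key}) below can
# switch between the underlying chat models; this is not verified here.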


def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
    """Build the RAG chain for one request from prior turns and a model key."""
    # Convert Gradio-style (user, assistant) tuples into LangChain messages.
    langchain_messages = []
    for human, ai in chat_history:
        langchain_messages.append(HumanMessage(content=human))
        langchain_messages.append(AIMessage(content=ai))

    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt().partial(history=langchain_messages)

    # Retrieve documents, reorder them so the most relevant sit at the edges of
    # the context (LongContextReorder), then run prompt -> LLM -> plain string.
    return (
        {
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )


def respond_stream(
    message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
    """Stream the answer chunk by chunk for the given question and history."""
    rag_chain = create_rag_chain(history, model)
    for chunk in rag_chain.stream(message):
        yield chunk


def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
    """Return the complete answer in a single call (non-streaming mode)."""
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


def get_model_key(label: str) -> str:
    """Map a display label from the dropdown back to its internal model key."""
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def validate_input(message: str) -> bool:
    """Check that the submitted message is not empty or whitespace-only."""
    return bool(message.strip())


def chat_function(
    message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
    """ChatInterface callback: validate the input, pick the model, and answer."""
    if not validate_input(message):
        yield "Please enter a message."
        return

    model_key = get_model_key(model_label)
    if STREAMING:
        # Accumulate chunks so Gradio re-renders the growing answer on each yield.
        response = ""
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        response = respond(message, history, model_key)
        yield response


with gr.Blocks(
    fill_height=True,
) as demo:
    gr.Markdown("# Supreme Court Precedent Consultation Assistant")
    gr.Markdown(
        "Hello! I am an AI assistant that answers questions about Supreme Court "
        "precedents. Feel free to ask at any time about searching, interpreting, "
        "or applying case law."
    )

    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="Select Model",
        value=list(AVAILABLE_MODELS.values())[0],
    )

    chatbot = gr.ChatInterface(
        fn=chat_function,
        autofocus=True,
        fill_height=True,
        multimodal=False,
        examples=[
            [
                "I bought a used car that turned out to be defective, and repairs "
                "cost 5 million won. Can I hold the seller legally responsible? "
                "Please introduce similar precedents.",
                "GPT-4o",
            ],
            [
                "I purchased about 2,000 pyeong of land, only to find that no house "
                "can be built on it. Please tell me about similar precedents "
                "involving real estate fraud.",
                "GPT-4o",
            ],
            [
                "An acquaintance swung a knife as a prank and left a roughly 20 cm "
                "cut on my arm. They insist it was only a joke. Please tell me about "
                "similar precedents involving bodily injury.",
                "GPT-4o",
            ],
        ],
        additional_inputs=[model_dropdown],
    )


if __name__ == "__main__":
    demo.launch()