# Page-status text captured together with this file ("Spaces: Running"); it is not part of the program.
# -------------------------------------- | |
# Chat with Documents | |
# キカガク 2023.4月期 最終成果アプリ | |
# Copyright. cawacci | |
# -------------------------------------- | |
# -------------------------------------- | |
# Libraries | |
# -------------------------------------- | |
import os | |
import time | |
import gc # メモリ解放 | |
import re # 正規表現で文章をクリーンアップ | |
# HuggingFace | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# OpenAI | |
import openai | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.chat_models import ChatOpenAI | |
# LangChain | |
import langchain | |
from langchain.llms import HuggingFacePipeline | |
from transformers import pipeline | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.chains import LLMChain, VectorDBQA | |
from langchain.vectorstores import Chroma | |
from langchain import PromptTemplate, ConversationChain | |
from langchain.chains.question_answering import load_qa_chain # QA Chat | |
from langchain.document_loaders import SeleniumURLLoader # URL取得 | |
from langchain.docstore.document import Document # テキストをドキュメント化 | |
from langchain.memory import ConversationSummaryBufferMemory # チャット履歴 | |
from typing import Any | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.tools import DuckDuckGoSearchRun | |
# Gradio | |
import gradio as gr | |
from pypdf import PdfReader | |
import requests # DeepL API request | |
# Mecab | |
import MeCab | |
# -------------------------------------- | |
# ユーザ別セッションの変数値を記録するクラス | |
# (参考)https://blog.shikoan.com/gradio-state/ | |
# -------------------------------------- | |
class SessionState:
    """Per-session container for models, chains, the vector DB and chat state.

    One instance is passed through the UI callbacks so that each session keeps
    its own models and conversation, and can reset them selectively with
    ``clear_memory``.
    """

    def __init__(self):
        # Hugging Face
        self.tokenizer = None
        self.pipe = None
        self.model = None
        # LangChain
        self.llm = None
        self.embeddings = None
        self.current_model = ""
        self.current_embedding = ""
        self.db = None  # vector DB
        self.memory = None  # LangChain chat memory
        self.conversation_chain = None  # ConversationChain
        self.query_generator = None  # query refiner that uses the chat history
        self.qa_chain = None  # load_qa_chain
        self.web_summary_chain = None  # summarises web search results
        self.embedded_urls = []
        self.similarity_search_k = None  # number of documents a similarity search returns
        self.summarization_mode = None  # Stuff / Map Reduce / Refine
        # Apps
        self.dialogue = []  # recent chat history for display
        # Reset by clear_memory(llm=True); declared here so the attribute always exists.
        self.chat_history = []

    def cache_clear(self):
        """Release cached GPU memory when CUDA is available, then run the garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # GPU memory
        gc.collect()  # CPU memory

    def clear_memory(self, llm=False, embd=False, db=False):
        """Drop the selected resources and free memory.

        Args:
            llm: Drop the LLM, its tokenizer/pipeline, the chat memory and every
                chain built on the LLM. Also drops the embeddings.
            embd: Drop the embeddings model and the QA chain.
            db: Delete the vector DB collection and forget the embedded URLs.
        """
        # Vector DB
        if db and self.db:
            self.db.delete_collection()
            self.db = None
            self.embedded_urls = []

        # Embeddings model
        if llm or embd:
            self.embeddings = None
            self.current_embedding = ""
            self.qa_chain = None

        # LLM model
        if llm:
            self.llm = None
            self.pipe = None
            self.model = None
            self.current_model = ""
            self.tokenizer = None
            self.memory = None
            self.chat_history = []
            # These chains hold a reference to the LLM. Without resetting them the
            # "is None" checks in set_chains keep the chains of the previous model
            # alive, and the old model can never be garbage collected.
            self.conversation_chain = None
            self.query_generator = None
            self.web_summary_chain = None

        self.cache_clear()
# -------------------------------------- | |
# 自作TextSplitter(テキストをLLMのトークン数内に分割) | |
# (参考)https://www.sato-susumu.com/entry/2023/04/30/131338 | |
# → 「!」、「?」、「)」、「.」、「!」、「?」、「,」などを追加 | |
# -------------------------------------- | |
class JPTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive character splitter whose separator list includes Japanese punctuation."""

    def __init__(self, **kwargs: Any):
        # Separators are tried in this order, from paragraph breaks down to single characters.
        # NOTE(review): "!" and "?" appear twice as captured; one of each pair was
        # presumably a full-width character in the original file - confirm.
        super().__init__(
            separators=["\n\n", "\n", "。", "!", "?", ")", "、", ".", "!", "?", ",", " ", ""],
            **kwargs,
        )
# Chunk splitting configuration
chunk_size, chunk_overlap = 512, 35  # maximum characters per chunk / per overlap
text_splitter = JPTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
# -------------------------------------- | |
# 文中から人名を抽出 | |
# -------------------------------------- | |
def name_detector(text: str) -> list:
    """Extract personal names from *text* using MeCab.

    A token counts as a family name when the fourth comma-separated field of
    its MeCab feature string is "姓", and as a given name when it is "名".
    A family name directly followed by a given name is returned joined as one
    name; a family or given name standing alone is returned by itself.

    Args:
        text: Text to analyse.

    Returns:
        Unique names in order of first appearance.
    """

    def name_part(node) -> str:
        # Return the fourth feature field, or "" when the node or field is missing.
        if node is None:
            return ""
        fields = node.feature.split(',')
        return fields[3] if len(fields) > 3 else ""

    mecab = MeCab.Tagger()
    mecab.parse('')  # workaround carried over from the original (marked there as a bug fix)
    node = mecab.parseToNode(text).next

    names = []
    while node:
        part = name_part(node)
        if part == "姓":
            if name_part(node.next) == "名":
                names.append(str(node.surface) + str(node.next.surface))
            else:
                names.append(str(node.surface))
        elif part == "名" and name_part(node.prev) != "姓":
            # A given name preceded by a family name was already added as a full name.
            names.append(str(node.surface))
        node = node.next

    # dict.fromkeys de-duplicates while keeping a stable order, so the prompt
    # text built from these names is reproducible (set() ordering is not).
    return list(dict.fromkeys(names))
# -------------------------------------- | |
# DeepL でメモリを翻訳しトークン数を削減(OpenAIモデル利用時) | |
# -------------------------------------- | |
DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")


def deepl_memory(ss: "SessionState") -> "SessionState":
    """Replace the latest user/AI exchange in memory with its English translation.

    Only acts when the current model is "gpt-3.5-turbo"; the section header in
    the file gives the purpose as reducing token usage. The last two messages
    of ``ss.memory.chat_memory`` are sent to DeepL, removed, and re-added in
    translated form.

    Args:
        ss: Session state whose memory is rewritten in place.

    Returns:
        The same session state.

    Raises:
        requests.HTTPError: If DeepL answers with an error status.
    """
    if ss.current_model != "gpt-3.5-turbo":
        return ss

    chat_memory = getattr(ss.memory, "chat_memory", None)
    # Nothing to translate until a full user/AI exchange has been stored.
    if chat_memory is None or len(chat_memory.messages) < 2:
        return ss

    user_message = chat_memory.messages[-2].content
    ai_message = chat_memory.messages[-1].content

    params = {
        "auth_key": DEEPL_API_KEY,
        "text": [user_message, ai_message],
        "target_lang": "EN",
        "source_lang": "JA",
        "tag_handling": "xml",
        "ignore_tags": "x",  # was misspelled "igonere_tags", so DeepL never received it
    }
    # A timeout keeps a stalled DeepL call from blocking the chat callback forever.
    response = requests.post(DEEPL_API_ENDPOINT, data=params, timeout=30)
    response.raise_for_status()
    translations = response.json()["translations"]

    # Swap the last exchange for its translation.
    chat_memory.messages = chat_memory.messages[:-2]
    chat_memory.add_user_message(translations[0]["text"])
    chat_memory.add_ai_message(translations[1]["text"])
    return ss
# -------------------------------------- | |
# DuckDuckGo Web検索結果を入力プロンプトに追加 | |
# -------------------------------------- | |
# DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate" | |
# DEEPL_API_KEY = os.getenv("DEEPL_API_KEY") | |
def web_search(ss: "SessionState", query) -> "tuple[SessionState, str]":
    """Build a prompt that augments *query* with a summarised web search result.

    Runs a DuckDuckGo search up to three times, summarising each result with
    ``ss.web_summary_chain`` and stopping at the first summary that is not
    "NO INFO". For "gpt-3.5-turbo" the query and the search result are first
    translated to English with DeepL. Personal names found in the original
    query and search result are appended as a translation guide.

    Args:
        ss: Session state providing the model name and the summary chain.
        query: The user's question.

    Returns:
        ``(ss, web_query)`` where ``web_query`` is the augmented prompt text.

    Raises:
        requests.HTTPError: If DeepL answers with an error status.
    """
    search = DuckDuckGoSearchRun(verbose=True)
    translate = ss.current_model == "gpt-3.5-turbo"
    # Search and name detection always use the untranslated question. Previously a
    # retry reused the English translation for the search and sent it to DeepL again
    # with source_lang="JA".
    search_query = query
    names = ""
    web_result = ""

    for _ in range(3):
        raw_result = search(search_query)

        # Personal names from the question and the search result, de-duplicated in order.
        detected = name_detector(search_query) + name_detector(raw_result)
        names = ", ".join(dict.fromkeys(detected))

        web_result = raw_result
        if translate:
            params = {
                "auth_key": DEEPL_API_KEY,
                "text": [search_query, raw_result],
                "target_lang": "EN",
                "source_lang": "JA",
                "tag_handling": "xml",
                "ignore_tags": "x",
            }
            response = requests.post(DEEPL_API_ENDPOINT, data=params, timeout=30)
            response.raise_for_status()  # same error handling as deepl_memory
            translations = response.json()["translations"]
            query = translations[0]["text"]
            web_result = translations[1]["text"]

        web_result = ss.web_summary_chain({'query': query, 'context': web_result})['text']
        if web_result != "NO INFO":
            break

    if names != "":
        web_query = "\n".join([
            f"{query}",
            "Use the following Suggested Answer Source as a reliable reference to answer the question above in Japanese. When translating names of people, refer to Names as a translation guide.",
            f"Suggested Answer Source: {web_result}",
            f"Names: {names}",
        ]).strip()
    else:
        web_query = query + "\nUse the following Suggested Answer Source as a reliable reference to answer the question above in the Japanese.\n===\nSuggested Answer Source: " + web_result + "\n"
    return ss, web_query
# -------------------------------------- | |
# LangChain カスタムプロンプト各種 | |
# llama tokenizer | |
# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/ | |
# OpenAI tokenizer | |
# https://platform.openai.com/tokenizer | |
# -------------------------------------- | |
# -------------------------------------- | |
# Conversation Chain Template | |
# -------------------------------------- | |
# Tokens: OpenAI 104/ Llama 105 <- In Japanese: Tokens: OpenAI 191/ Llama 162 | |
# System message for the plain conversation chain.
# Written as adjacent literals with explicit spaces: the earlier
# triple-quoted text + .replace("\n", "") joined lines without a separator, and
# the wording told the model to "do make up an answer" instead of "do not".
sys_chat_message = (
    "You are an outstanding AI concierge. Understand the intent of the customer's questions based on "
    "the conversation history. Then, answer them with many specific and detailed information in Japanese. "
    "If you do not know the answer to a question, do not make up an answer and say "
    '"誠に申し訳ございませんが、その点についてはわかりかねます".'
)
# Variable part shared by every model flavour ({query} / {chat_history} are filled by LangChain).
chat_common_format = """
===
Question: {query}
===
Conversation History:
{chat_history}
===
日本語の回答: """
chat_template_std = f"{sys_chat_message}{chat_common_format}"
# Llama-2 chat format: system message inside <<SYS>>, whole prompt inside [INST].
chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"
# -------------------------------------- | |
# QA Chain Template (Stuff) | |
# -------------------------------------- | |
# Tokens: OpenAI 113/ Llama 111 <- In Japanese: Tokens: OpenAI 256/ Llama 225 | |
# System message for the "stuff" QA chain.
# Adjacent literals with explicit spaces replace the triple-quoted text +
# .replace("\n", ""), which joined lines without a separator ("If you donot know").
sys_qa_message = (
    "You are an AI concierge who carefully answers questions from customers based on references. "
    "Understand the intent of the customer's questions based on the conversation history. Then, give "
    "a specific answer in Japanese using sentences extracted from the following references. If you do "
    'not know the answer, do not make up an answer and reply, "誠に申し訳ございませんが、その点についてはわかりかねます".'
)
# Variable part shared by every model flavour.
qa_common_format = """
===
Question: {query}
References: {context}
===
Conversation History:
{chat_history}
===
日本語の回答: """
qa_template_std = f"{sys_qa_message}{qa_common_format}"
# Llama-2 chat format: system message inside <<SYS>>, whole prompt inside [INST].
qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"
# -------------------------------------- | |
# QA Chain Template (Map Reduce) | |
# -------------------------------------- | |
# 1. Prompt for the chain that rewrites the latest question using the conversation history.
# Adjacent literals with explicit spaces replace triple-quoted text + .replace("\n", ""),
# which joined the lines without any separator.
query_generator_message = (
    "Referring to the \"Conversation History\", reformat the user's \"Additional Question\" "
    "to a specific question by filling in the missing subject, verb, objects, complements, "
    "and other necessary information to get a better search result. Answer in Japanese."
)
query_generator_common_format = """
===
[Conversation History]
{chat_history}
[Additional Question] {query}
明確な日本語の質問文: """
query_generator_template_std = f"{query_generator_message}{query_generator_common_format}"
query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]"

# 2. Prompt for the chain that summarises each reference against the generated question.
# The two numbered steps stay on separate lines; the literal answer "NO INFO" is what
# web_search compares against.
question_prompt_message = "\n".join([
    '1. Determine if any of the following references provide information that answers the Question, and if there is no information, answer "NO INFO" and stop.',
    "2. From the following references, extract key information relevant to the question and summarize it in a natural English sentence with clear subject, verb, object, and complement.",
])
question_prompt_common_format = """
===
[Question] {query}
[references] {context}
[Answer]"""
question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}"
question_prompt_template_llama2 = f"<s>[INST] <<SYS>>{question_prompt_message}<</SYS>>{question_prompt_common_format}[/INST]"

# 3. Prompt for the chain that answers from the generated question and the per-reference summaries.
combine_prompt_message = (
    "You are an AI concierge who carefully answers questions from customers based on references. "
    "Provide a specific answer in Japanese using sentences extracted from the following references. "
    "If you do not know the answer, do not make up an answer and reply, "
    '"誠に申し訳ございませんが、その点についてはわかりかねます".'
)
combine_prompt_common_format = """
===
Question: {query}
Reference: {summaries}
日本語の回答: """
combine_prompt_template_std = f"{combine_prompt_message}{combine_prompt_common_format}"
combine_prompt_template_llama2 = f"<s>[INST] <<SYS>>{combine_prompt_message}<</SYS>>{combine_prompt_common_format}[/INST]"
# -------------------------------------- | |
# ConversationSummaryBufferMemoryの要約プロンプト | |
# ソース → https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49 | |
# -------------------------------------- | |
# Tokens: OpenAI 212/ Llama 214 <- In Japanese: Tokens: OpenAI 397/ Llama 297 | |
# Summary prompt for ConversationSummaryBufferMemory: {summary} is the running
# summary and {new_lines} the latest exchange. The worked example shows a
# Japanese conversation being condensed into an English summary.
conversation_summary_template = """
Using the example as a guide, compose a summary in English that gives an overview of the conversation by summarizing the "current summary" and the "new conversation".
===
Example
[Current Summary] Customer asks AI what it thinks about Artificial Intelligence, AI says Artificial Intelligence is a good tool.
[New Conversation]
Human: なぜ人工知能が良いツールだと思いますか?
AI: 人工知能は「人間の可能性を最大限に引き出すことを助ける」からです。
[New Summary] Customer asks what you think about Artificial Intelligence, and AI responds that it is a good force that helps humans reach their full potential.
===
[Current Summary] {summary}
[New Conversation]
{new_lines}
[New Summary]
""".strip()
# モデル読み込み | |
def load_models(
    ss: "SessionState",
    model_id: str,
    embedding_id: str,
    openai_api_key: str,
    load_in_8bit: bool,
    verbose: bool,
    temperature: float,
    similarity_search_k: int,
    summarization_mode: str,
    min_length: int,
    max_new_tokens: int,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    num_return_sequences: int,
) -> "tuple[SessionState, str]":
    """Load the selected LLM and embeddings model and rebuild the chains.

    Args:
        ss: Session state that receives the models and chains.
        model_id: "gpt-3.5-turbo" or a Hugging Face causal-LM id.
        embedding_id: "text-embedding-ada-002", a Hugging Face embeddings id,
            or the string "None" for no embeddings.
        openai_api_key: Not read here; the key is taken from the
            OPENAI_API_KEY environment variable.
        load_in_8bit: Passed to ``AutoModelForCausalLM.from_pretrained``.
        verbose: Passed to ``ChatOpenAI``.
        temperature: Sampling temperature for either backend.
        similarity_search_k: Stored on the session for later similarity searches.
        summarization_mode: Stored on the session and passed to ``set_chains``.
        min_length, max_new_tokens, top_k, top_p, repetition_penalty,
        num_return_sequences: Generation settings for the Hugging Face pipeline
            (``max_new_tokens`` is also used as ``max_tokens`` for OpenAI).

    Returns:
        ``(ss, status_message)``.
    """
    # Remember the search settings.
    ss.similarity_search_k = similarity_search_k
    ss.summarization_mode = summarization_mode

    # OpenAI models need an API key. .get() avoids a KeyError when the variable
    # has never been set.
    needs_openai = model_id == "gpt-3.5-turbo" or embedding_id == "text-embedding-ada-002"
    if needs_openai and not os.environ.get("OPENAI_API_KEY"):
        status_message = "❌ OpenAI API KEY を設定してください"
        return ss, status_message

    # LLM: both backends start from a clean slate.
    ss.clear_memory(llm=True, db=True)
    if model_id == "gpt-3.5-turbo":
        ss.llm = ChatOpenAI(
            model_name=model_id,
            temperature=temperature,
            verbose=verbose,
            max_tokens=max_new_tokens,
        )
    else:
        if model_id == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        else:
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id)

        ss.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        ss.pipe = pipeline(
            "text-generation",
            model=ss.model,
            tokenizer=ss.tokenizer,
            min_length=min_length,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_return_sequences,
            temperature=temperature,
        )
        ss.llm = HuggingFacePipeline(pipeline=ss.pipe)

    # Record the model name BEFORE building the chains: set_chains picks the
    # Rinna / ELYZA prompt templates from ss.current_model, which clear_memory
    # has just reset to "".
    ss.current_model = model_id

    # Embeddings model
    if ss.current_embedding != embedding_id:
        # Reset embeddings and the vector DB.
        ss.clear_memory(embd=True, db=True)
        if embedding_id == "None":
            pass
        elif embedding_id == "text-embedding-ada-002":
            ss.embeddings = OpenAIEmbeddings()
        else:
            ss.embeddings = HuggingFaceEmbeddings(model_name=embedding_id)
    ss.current_embedding = embedding_id

    # Chains
    ss = set_chains(ss, summarization_mode)

    status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding
    return ss, status_message
# -------------------------------------- | |
# Conversation/QA Chain 呼び出し統合 | |
# -------------------------------------- | |
def set_chains(ss: "SessionState", summarization_mode) -> "SessionState":
    """Create the chat memory and every chain that is still missing on *ss*.

    Prompt templates are chosen from ``ss.current_model``: Rinna gets "<NL>"
    instead of newlines plus its own speaker prefixes, ELYZA models get the
    Llama-2 templates, everything else the standard ones. Components that
    already exist on *ss* are left untouched.

    Args:
        ss: Session state holding ``llm`` and receiving memory and chains.
        summarization_mode: "stuff" or "map_reduce"; selects the QA chain type.

    Returns:
        The same session state.
    """
    # Defaults (standard templates)
    human_prefix = "Human: "
    ai_prefix = "AI: "
    chat_template = chat_template_std
    qa_template = qa_template_std
    query_generator_template = query_generator_template_std
    question_template = question_prompt_template_std
    combine_template = combine_prompt_template_std

    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        # Rinna: newline token and memory prefixes (the original comment cites the official model page).
        chat_template = chat_template.replace("\n", "<NL>")
        qa_template = qa_template.replace("\n", "<NL>")
        query_generator_template = query_generator_template_std.replace("\n", "<NL>")
        question_template = question_prompt_template_std.replace("\n", "<NL>")
        combine_template = combine_prompt_template_std.replace("\n", "<NL>")
        human_prefix = "ユーザー: "
        ai_prefix = "システム: "
    elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
        # ELYZA: Llama-2 chat templates.
        chat_template = chat_template_llama2
        qa_template = qa_template_llama2
        query_generator_template = query_generator_template_llama2
        question_template = question_prompt_template_llama2
        combine_template = combine_prompt_template_llama2

    # Memory
    if ss.memory is None:
        conversation_summary_prompt = PromptTemplate(
            input_variables=['summary', 'new_lines'],
            template=conversation_summary_template,
        )
        ss.memory = ConversationSummaryBufferMemory(
            llm=ss.llm,
            memory_key="chat_history",
            input_key="query",
            output_key="output_text",
            return_messages=False,
            human_prefix=human_prefix,
            ai_prefix=ai_prefix,
            max_token_limit=1024,
            prompt=conversation_summary_prompt,
        )

    # Query refiner
    if ss.query_generator is None:
        query_generator_prompt = PromptTemplate(
            template=query_generator_template,
            input_variables=["chat_history", "query"],
        )
        ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt, verbose=True)

    # Conversation chain
    if ss.conversation_chain is None:
        chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
        ss.conversation_chain = ConversationChain(
            llm=ss.llm,
            prompt=chat_prompt,
            # The shared memory supplies {chat_history}. It was commented out before,
            # leaving the prompt's chat_history variable without a provider and
            # ss.memory without the chat turns that bot() later edits.
            memory=ss.memory,
            input_key="query",
            output_key="output_text",
            verbose=True,
        )

    # QA chain
    # NOTE(review): any other summarization_mode leaves ss.qa_chain as None - confirm
    # the UI only offers "stuff" and "map_reduce".
    if ss.qa_chain is None:
        if summarization_mode == "stuff":
            qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)
            ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)
        elif summarization_mode == "map_reduce":
            question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
            combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"])
            ss.qa_chain = load_qa_chain(
                ss.llm,
                chain_type="map_reduce",
                return_map_steps=True,
                memory=ss.memory,
                question_prompt=question_prompt,
                combine_prompt=combine_prompt,
            )

    # Web search summariser (reuses the per-reference question prompt)
    if ss.web_summary_chain is None:
        question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
        ss.web_summary_chain = LLMChain(llm=ss.llm, prompt=question_prompt, verbose=True)

    return ss
def initialize_db(ss: SessionState) -> SessionState:
    """Create the "user_reference" Chroma collection with the session's embeddings."""
    # No client is passed; a chromadb.PersistentClient(path="./db") variant was
    # left commented out in the original.
    vector_db = Chroma(collection_name="user_reference", embedding_function=ss.embeddings)
    ss.db = vector_db
    return ss
def _clean_page_content(content: str, is_pdf: bool) -> str:
    """Normalise whitespace in one document; PDFs get extra line-break repair first."""
    content = content.strip()

    if is_pdf:
        # Line breaks that follow sentence-ending punctuation (or precede an
        # indent) are parked as placeholders so they survive the clean-up below.
        # NOTE(review): several entries are duplicated or map one mark to a
        # different one as captured (e.g. ',\n' -> '。'); full-width characters
        # were presumably normalised on the way here - confirm against the real file.
        pdf_replacement_sets = [
            ('\n ', '**PLACEHOLDER+SPACE**'),
            ('\n\u3000', '**PLACEHOLDER+SPACE**'),
            ('.\n', '。**PLACEHOLDER**'),
            (',\n', '。**PLACEHOLDER**'),
            ('?\n', '。**PLACEHOLDER**'),
            ('!\n', '。**PLACEHOLDER**'),
            ('!\n', '。**PLACEHOLDER**'),
            ('。\n', '。**PLACEHOLDER**'),
            ('!\n', '!**PLACEHOLDER**'),
            (')\n', '!**PLACEHOLDER**'),
            (']\n', '!**PLACEHOLDER**'),
            ('?\n', '?**PLACEHOLDER**'),
            (')\n', '?**PLACEHOLDER**'),
            ('】\n', '?**PLACEHOLDER**'),
        ]
        for original, replacement in pdf_replacement_sets:
            content = content.replace(original, replacement)
        content = content.replace(" ", "")

    # Remove unwanted characters.
    # NOTE(review): " " here removes every single space; the original comment talks
    # about *consecutive* spaces, so a double space may have been intended - confirm.
    for remove_text in ["\n", "\r", " "]:
        content = content.replace(remove_text, "")

    # Tabs and full-width spaces become a single space.
    for replace_text in ["\t", "\u3000"]:
        content = content.replace(replace_text, " ")

    if is_pdf:
        # Restore the line breaks that were parked above.
        content = content.replace('**PLACEHOLDER**', '\n').replace('**PLACEHOLDER+SPACE**', '\n ')
    return content


def embedding_process(ss: "SessionState", ref_documents: "list[Document]") -> "SessionState":
    """Clean, chunk and embed *ref_documents* into the session's vector DB.

    Args:
        ss: Session state providing the embeddings and the vector DB
            (created on first use).
        ref_documents: LangChain documents; ``metadata['source']`` containing
            ".pdf" switches on the stronger PDF clean-up.

    Returns:
        The same session state. It is returned unchanged when there is nothing
        to embed.
    """
    # Nothing to do - avoids handing an empty batch to the vector DB.
    if not ref_documents:
        return ss

    # Clean each document in place.
    for document in ref_documents:
        source = document.metadata.get('source') or ""
        document.page_content = _clean_page_content(document.page_content, ".pdf" in source)

    # Split into chunks.
    texts = text_splitter.split_documents(ref_documents)
    if not texts:
        return ss

    # multilingual-e5 expects a "passage:" prefix on stored text
    # (https://hironsan.hatenablog.com/entry/2023/07/05/073150).
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        for chunk in texts:
            chunk.page_content = "passage:" + chunk.page_content

    # Create the vector DB on first use, then embed.
    if ss.db is None:
        ss = initialize_db(ss)
    ss.db.add_documents(documents=texts, embedding=ss.embeddings)
    return ss
def _extract_pdf_body(pdf_path, header_lim, footer_lim) -> str:
    """Return the PDF's text, keeping only fragments between the footer and header limits."""
    page_texts = []
    for page in PdfReader(pdf_path).pages:
        parts = []

        def visitor_body(text, cm, tm, font_dict, font_size):
            # tm[5] is used as the fragment's y coordinate; keep it only when it
            # lies strictly between the footer and header limits.
            y = tm[5]
            if footer_lim < y < header_lim:
                parts.append(text)

        page.extract_text(visitor_text=visitor_body)
        page_texts.append("".join(parts))
    return "\n".join(page_texts)


def embed_ref(ss: "SessionState", urls: str, fileobj: list, header_lim: int, footer_lim: int) -> "tuple[SessionState, str]":
    """Register web pages and PDF files as reference data in the vector DB.

    Args:
        ss: Session state; needs both an LLM and an embeddings model.
        urls: Newline-separated URLs. Lines without "://" are ignored.
        fileobj: Uploaded file objects exposing ``.name`` (a path), or None.
        header_lim: Upper y limit; PDF text at or above it is dropped.
        footer_lim: Lower y limit; PDF text at or below it is dropped.

    Returns:
        ``(ss, status_message)`` describing what was registered.
    """
    # Models must be loaded first.
    if ss.llm is None or ss.embeddings is None:
        status_message = "❌ LLM/Embeddingモデルが登録されていません。"
        return ss, status_message

    url_flag = "-"
    pdf_flag = "-"

    # URLs: normalise, drop non-URLs and duplicates (order preserved).
    lines = (line.strip() for line in (urls or "").split("\n"))
    submitted_urls = list(dict.fromkeys(url for url in lines if url and "://" in url))
    if submitted_urls:
        new_urls = [url for url in submitted_urls if url not in ss.embedded_urls]
        if new_urls:
            ref_documents = SeleniumURLLoader(urls=new_urls).load()
            ss = embedding_process(ss, ref_documents)
            # Recorded only after embedding succeeded, so a failed load can be retried.
            ss.embedded_urls.extend(new_urls)
        url_flag = "✅ 登録済"

    # PDFs: strip headers/footers, then embed
    # (https://pypdf.readthedocs.io/en/stable/user/extract-text.html).
    if fileobj:
        ref_documents = []
        for pdf_path in [path.name for path in fileobj]:
            body = _extract_pdf_body(pdf_path, header_lim, footer_lim)
            # Only the file name is kept as the source.
            filename = os.path.basename(pdf_path)
            ref_documents.append(Document(page_content=body, metadata={"source": filename}))
        ss = embedding_process(ss, ref_documents)
        pdf_flag = "✅ 登録済"

    # NOTE(review): this switches on LangChain's global debug flag as a side effect of
    # registering references; kept as in the original - confirm it is intended.
    langchain.debug = True

    status_message = "URL: " + url_flag + " / PDF: " + pdf_flag
    return ss, status_message
def clear_db(ss: "SessionState") -> "tuple[SessionState, str]":
    """Delete the reference data held in the session's vector DB.

    Returns:
        ``(ss, status_message)``. The message reports whether data was deleted
        or whether nothing was registered.
    """
    if ss.db is None:
        status_message = "❌ 参照データが登録されていません。"
        return ss, status_message

    ss.db.delete_collection()
    # Forget the deleted collection and its URLs. Otherwise later searches hit a
    # collection that no longer exists, and the same URLs are skipped as
    # "already registered" on the next upload.
    ss.db = None
    ss.embedded_urls = []
    status_message = "✅ 参照データを削除しました。"
    return ss, status_message
# ---------------------------------------------------------------------------- | |
# query入力 ▶ [def user] ▶ [ def bot ] ▶ [def show_response] ▶ チャットボット画面 | |
# ⬇ ⬇ ⬆ | |
# チャットボット画面 [qa_predict / conversation_predict] | |
# ---------------------------------------------------------------------------- | |
def user(ss: SessionState, query) -> (SessionState, list):
    """Add the new query to the displayed dialogue and return it for the chat window."""
    # Once the history holds more than 20 entries, drop the oldest one.
    if len(ss.dialogue) > 20:
        ss.dialogue.pop(0)

    # None is the still-empty answer slot for the bot.
    ss.dialogue = [*ss.dialogue, (query, None)]

    # The chat window shows the dialogue itself.
    return ss, ss.dialogue
def _answer_latest_turn(ss, response) -> None:
    """Write *response* into the newest dialogue entry, if there is one."""
    if ss.dialogue:
        ss.dialogue[-1] = (ss.dialogue[-1][0], response)


def bot(ss: "SessionState", query, qa_flag, web_flag, summarization_mode) -> "tuple[SessionState, str]":
    """Answer *query* with the QA chain or the conversation chain.

    The query is first rewritten with the chat history (when there is any),
    optionally augmented with a web search, and then answered. The answer is
    written into ``ss.dialogue`` by the predict functions.

    Args:
        ss: Session state with models, chains and memory.
        query: The user's question.
        qa_flag: ``True`` to answer from the reference documents.
        web_flag: Truthy to add a web search result to the prompt.
        summarization_mode: Not used here; kept for the callback signature.

    Returns:
        ``(ss, "")`` - the empty string clears the query box.
    """
    # Configuration problems are reported in the chat itself.
    if ss.llm is None:
        _answer_latest_turn(ss, "LLMが設定されていません。設定画面で任意のモデルを選択してください。")
        return ss, ""
    if qa_flag is True and ss.embeddings is None:
        _answer_latest_turn(ss, "Embeddingモデルが設定されていません。設定画面で任意のモデルを選択してください。")
        return ss, ""
    if qa_flag is True and ss.db is None:
        _answer_latest_turn(ss, "参照データが登録されていません。")
        return ss, ""

    # Refine the query with the chat history. Pass the history TEXT: the whole
    # memory dict used to be passed, so its repr ended up in the prompt.
    chat_history = ss.memory.load_memory_variables({})['chat_history']
    if chat_history != "":
        query = ss.query_generator({"query": query, "chat_history": chat_history})['text']

    # After the guards above, qa_flag alone decides between QA and plain chat.
    predict = qa_predict if qa_flag is True else chat_predict
    if web_flag:
        ss, web_query = web_search(ss, query)
        ss = predict(ss, web_query)
        # Store the plain query in memory instead of the long web-augmented prompt.
        messages = ss.memory.chat_memory.messages
        if len(messages) >= 2:
            messages[-2].content = query
    else:
        ss = predict(ss, query)

    # With the GPT model, memory is translated to English via DeepL.
    ss = deepl_memory(ss)

    return ss, ""
def chat_predict(ss: SessionState, query) -> SessionState:
    """Answer *query* with the conversation chain and store the answer in the dialogue."""
    answer = ss.conversation_chain.predict(query=query)
    # Keep the question of the newest turn and fill in its answer slot.
    latest_question = ss.dialogue[-1][0]
    ss.dialogue[-1] = (latest_question, answer)
    return ss
def qa_predict(ss: "SessionState", query) -> "SessionState":
    """Answer *query* from the reference documents in the vector DB.

    Retrieves ``ss.similarity_search_k`` similar chunks, runs the QA chain up to
    three times until it produces a non-empty answer, and writes the answer
    (followed by the list of sources) into the newest dialogue entry.

    Args:
        ss: Session state with vector DB, QA chain, memory and dialogue.
        query: The question, possibly already refined or web-augmented.

    Returns:
        The same session state.
    """
    is_rinna = ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft"
    original_query = query

    # Rinna uses "<NL>" instead of newlines.
    query = query.strip()
    if is_rinna:
        query = query.replace("\n", "<NL>")

    # multilingual-e5 expects a "query: " prefix on search text.
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        db_query_str = "query: " + query
    else:
        db_query_str = query

    # Retrieve related chunks and their sources.
    docs = ss.db.similarity_search(db_query_str, k=ss.similarity_search_k)
    # Unique sources in retrieval order, every entry with its own " - " bullet
    # (the join used to leave the first entry without one).
    source_names = list(dict.fromkeys(doc.metadata['source'] for doc in docs if 'source' in doc.metadata))
    if source_names:
        sources = "\n\n[Sources]\n" + "\n".join(f" - {name}" for name in source_names)
    else:
        sources = ""

    if is_rinna:
        for doc in docs:
            doc.page_content = doc.page_content.strip().replace("\n", "<NL>")

    # Generate the answer (up to three attempts).
    for _ in range(3):
        result = ss.qa_chain({"input_documents": docs, "query": query})
        # str.strip(chars) removes any of the given characters from both ends.
        answer = result["output_text"].replace("<NL>", "\n").strip("...").strip("回答:").strip()
        result["output_text"] = answer

        if answer != "":
            ss.dialogue[-1] = (ss.dialogue[-1][0], answer + sources)
            return ss

        # Empty answer: drop the exchange the chain just stored and try again.
        ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]

    # Still empty after three attempts.
    response = "3回試行しましたが、情報を生成できませんでした。"
    if source_names:
        response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"

    # Record the failed exchange in memory and in the dialogue.
    ss.memory.chat_memory.add_user_message(original_query.replace("<NL>", "\n"))
    ss.memory.chat_memory.add_ai_message(response)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)
    return ss
# 回答を1文字ずつチャット画面に表示する | |
def show_response(ss: "SessionState", delay: float = 0.05) -> "Iterator[list]":
    """Yield the dialogue repeatedly while the newest answer is revealed one character at a time.

    Args:
        ss: Session state whose ``dialogue`` holds ``(question, answer)`` tuples.
        delay: Seconds to wait after each character (0 disables the wait).

    Yields:
        The dialogue as a list of ``[question, answer]`` lists. The same list
        object is yielded each time, with the last answer grown by one character.
        Nothing is yielded when the dialogue or the answer is empty.
    """
    # Tuples become lists so the last answer can be rebuilt in place.
    chat_history = [list(item) for item in ss.dialogue]
    if not chat_history:
        return

    response = chat_history[-1][1]
    if response is None:
        response = "回答を生成できませんでした。"

    chat_history[-1][1] = ""  # start from an empty answer for the incremental display
    for character in response:
        chat_history[-1][1] += character
        if delay > 0:
            time.sleep(delay)
        yield chat_history
with gr.Blocks() as demo: | |
# ユーザ別セッションメモリのインスタンス化(リロードでリセット) | |
ss = gr.State(SessionState()) | |
# -------------------------------------- | |
# API KEY をセット/クリアする関数 | |
# -------------------------------------- | |
def openai_api_setfn(openai_api_key) -> str:
    """Validate the entered OpenAI API key and store it in OPENAI_API_KEY.

    The pass phrase "kikagaku" selects the demo key kept in the
    ``kikagaku_demo`` environment variable. Any other value must start with
    "sk-" and be at least 50 characters long.

    Returns:
        A status message for the UI.
    """
    # NOTE(review): os.environ is process-wide, so a key set here is presumably
    # visible to every session served by this process - confirm that is acceptable.
    invalid_message = "❌ 有効なAPIキーを入力してください"

    if openai_api_key == "kikagaku":
        demo_key = os.getenv("kikagaku_demo")
        if not demo_key:
            # Assigning None to os.environ raises TypeError; report an invalid key instead.
            os.environ["OPENAI_API_KEY"] = ""
            return invalid_message
        os.environ["OPENAI_API_KEY"] = demo_key
        return "✅ キカガク専用DEMOへようこそ!APIキーを設定しました"

    if not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
        os.environ["OPENAI_API_KEY"] = ""
        return invalid_message

    os.environ["OPENAI_API_KEY"] = openai_api_key
    return "✅ APIキーを設定しました"
def openai_api_clsfn(ss) -> (str, str): | |
openai_api_key = "" | |
os.environ["OPENAI_API_KEY"] = "" | |
status_message = "✅ APIキーの削除が完了しました" | |
return status_message, "" | |
with gr.Tabs(): | |
# -------------------------------------- | |
# Setting Tab | |
# -------------------------------------- | |
with gr.TabItem("1. LLM設定"): | |
with gr.Row(): | |
model_id = gr.Dropdown( | |
choices=[ | |
'elyza/ELYZA-japanese-Llama-2-7b-fast-instruct', | |
'rinna/bilingual-gpt-neox-4b-instruction-sft', | |
'gpt-3.5-turbo', | |
], | |
value="gpt-3.5-turbo", | |
label='LLM model', | |
interactive=True, | |
) | |
with gr.Row(): | |
embedding_id = gr.Dropdown( | |
choices=[ | |
'intfloat/multilingual-e5-large', | |
'sonoisa/sentence-bert-base-ja-mean-tokens-v2', | |
'oshizo/sbert-jsnli-luke-japanese-base-lite', | |
'text-embedding-ada-002', | |
"None" | |
], | |
value="text-embedding-ada-002", | |
label = 'Embedding model', | |
interactive=True, | |
) | |
with gr.Row(): | |
with gr.Column(scale=19): | |
openai_api_key = gr.Textbox(label="OpenAI API Key (Optional)", interactive=True, type="password", value="", placeholder="Your OpenAI API Key for OpenAI models.", max_lines=1) | |
with gr.Column(scale=1): | |
openai_api_set = gr.Button(value="Set API KEY", size="sm") | |
openai_api_cls = gr.Button(value="Delete API KEY", size="sm") | |
# with gr.Row(): | |
# reference_libs = gr.CheckboxGroup(choices=['LangChain', 'Gradio'], label="Reference Libraries", interactive=False) | |
# 詳細設定(折りたたみ) | |
with gr.Accordion(label="Advanced Setting", open=False): | |
with gr.Row(): | |
with gr.Column(): | |
load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True) | |
verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=True) | |
with gr.Column(): | |
temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True) | |
with gr.Column(): | |
similarity_search_k = gr.Slider(label="similarity_search_k (OpenAI, HF)", minimum=1, maximum=10, step=1, value=3, interactive=True) | |
with gr.Column(): | |
summarization_mode = gr.Radio(choices=['stuff', 'map_reduce'], label="Summarization mode", value='stuff', interactive=True) | |
with gr.Column(): | |
min_length = gr.Slider(label="min_length (HF)", minimum=1, maximum=100, step=1, value=10, interactive=True) | |
with gr.Column(): | |
max_new_tokens = gr.Slider(label="max_tokens(OpenAI), max_new_tokens(HF)", minimum=1, maximum=1024, step=1, value=256, interactive=True) | |
with gr.Column(): | |
top_k = gr.Slider(label='top_k (HF)', minimum=1, maximum=100, step=1, value=40, interactive=True) | |
with gr.Column(): | |
top_p = gr.Slider(label='top_p (HF)', minimum=0.01, maximum=0.99, step=0.01, value=0.92, interactive=True) | |
with gr.Column(): | |
repetition_penalty = gr.Slider(label='repetition_penalty (HF)', minimum=0.5, maximum=2, step=0.1, value=1.2, interactive=True) | |
with gr.Column(): | |
num_return_sequences = gr.Slider(label='num_return_sequences (HF)', minimum=1, maximum=20, step=1, value=3, interactive=True) | |
with gr.Row(): | |
with gr.Column(scale=2): | |
config_btn = gr.Button(value="Configure") | |
with gr.Column(scale=13): | |
status_cfg = gr.Textbox(show_label=False, interactive=False, value="モデルを設定してください", container=False, max_lines=1) | |
# ボタン等のアクション設定 | |
openai_api_set.click(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full") | |
openai_api_cls.click(openai_api_clsfn, inputs=[openai_api_key], outputs=[status_cfg, openai_api_key], show_progress="full") | |
openai_api_key.submit(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full") | |
config_btn.click( | |
fn = load_models, | |
inputs = [ss, model_id, embedding_id, openai_api_key, load_in_8bit, verbose, temperature, \ | |
similarity_search_k, summarization_mode, \ | |
min_length, max_new_tokens, top_k, top_p, repetition_penalty, num_return_sequences], | |
outputs = [ss, status_cfg], | |
queue = True, | |
show_progress = "full" | |
) | |
# -------------------------------------- | |
# Reference Tab | |
# -------------------------------------- | |
with gr.TabItem("2. References"): | |
urls = gr.TextArea( | |
max_lines = 60, | |
show_label=False, | |
info = "List any reference URLs for Q&A retrieval.", | |
placeholder = "https://blog.kikagaku.co.jp/deep-learning-transformer\nhttps://note.com/elyza/n/na405acaca130", | |
interactive=True, | |
) | |
with gr.Row(): | |
pdf_paths = gr.File(label="PDFs", height=150, min_width=60, scale=7, file_types=[".pdf"], file_count="multiple", interactive=True) | |
header_lim = gr.Number(label="Header (pt)", step=1, value=792, precision=0, min_width=70, scale=1, interactive=True) | |
footer_lim = gr.Number(label="Footer (pt)", step=1, value=0, precision=0, min_width=70, scale=1, interactive=True) | |
pdf_ref = gr.Textbox(show_label=False, value="A4 Size:\n(下)0-792pt(上)\n *28.35pt/cm", container=False, scale=1, interactive=False) | |
with gr.Row(): | |
ref_set_btn = gr.Button(value="コンテンツ登録", scale=1) | |
ref_clear_btn = gr.Button(value="登録データ削除", scale=1) | |
status_ref = gr.Textbox(show_label=False, interactive=False, value="参照データ未登録", container=False, max_lines=1, scale=18) | |
ref_set_btn.click(fn=embed_ref, inputs=[ss, urls, pdf_paths, header_lim, footer_lim], outputs=[ss, status_ref], queue=True, show_progress="full") | |
ref_clear_btn.click(fn=clear_db, inputs=[ss], outputs=[ss, status_ref], show_progress="full") | |
# -------------------------------------- | |
# Chatbot Tab | |
# -------------------------------------- | |
with gr.TabItem("3. Q&A Chat"): | |
chat_history = gr.Chatbot([], elem_id="chatbot", avatar_images=["bear.png", "penguin.png"],) | |
with gr.Row(): | |
with gr.Column(scale=95): | |
query = gr.Textbox( | |
show_label=False, | |
placeholder="Send a message with [Shift]+[Enter] key.", | |
lines=4, | |
container=False, | |
autofocus=True, | |
interactive=True, | |
) | |
with gr.Column(scale=5): | |
with gr.Row(): | |
qa_flag = gr.Checkbox(label="QA mode", value=False, min_width=60, interactive=True) | |
web_flag = gr.Checkbox(label="Web Search", value=True, min_width=60, interactive=True) | |
with gr.Row(): | |
query_send_btn = gr.Button(value="▶") | |
# gr.Examples(["機械学習について説明してください"], inputs=[query]) | |
query.submit( | |
user, [ss, query], [ss, chat_history] | |
).then( | |
bot, [ss, query, qa_flag, web_flag, summarization_mode], [ss, query] | |
).then( | |
show_response, [ss], [chat_history] | |
) | |
query_send_btn.click( | |
user, [ss, query], [ss, chat_history] | |
).then( | |
bot, [ss, query, qa_flag, web_flag, summarization_mode], [ss, query] | |
).then( | |
show_response, [ss], [chat_history] | |
) | |
if __name__ == "__main__":
    # Script entry point: enable the request queue with concurrency_count=5,
    # then launch the app in debug mode. queue() returns the Blocks instance,
    # so the two calls can be chained.
    demo.queue(concurrency_count=5).launch(debug=True)