Upload 5 files

- app.py +1022 -0
- bear.png +0 -0
- packages.txt +1 -0
- penguin.png +0 -0
- requirements.txt +23 -0
app.py
ADDED
@@ -0,0 +1,1022 @@
# Chat with Documents 2 by cawacci
# 2023.9.10: built as the capstone app for the Kikagaku long-term course (April 2023 term)

# --------------------------------------
# Libraries
# --------------------------------------
import os
import time
import gc  # memory cleanup
import re  # regex-based text cleanup

# Hugging Face
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# OpenAI
import openai
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI

# LangChain
from langchain.llms import HuggingFacePipeline
from transformers import pipeline

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import LLMChain, VectorDBQA
from langchain.vectorstores import Chroma

from langchain import PromptTemplate, ConversationChain
from langchain.chains.question_answering import load_qa_chain  # QA chat
from langchain.document_loaders import SeleniumURLLoader       # URL fetching
from langchain.docstore.document import Document               # wrap text as Documents
from langchain.memory import ConversationSummaryBufferMemory   # chat history

from typing import Any
from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain.tools import DuckDuckGoSearchRun

# Gradio
import gradio as gr
from pypdf import PdfReader
import requests  # DeepL API requests

# test
import langchain  # (to enable langchain.debug = True)

# --------------------------------------
# Class that stores per-user session state
# (Reference) https://blog.shikoan.com/gradio-state/
# --------------------------------------
class SessionState:
    def __init__(self):
        # Hugging Face
        self.tokenizer = None
        self.pipe = None
        self.model = None

        # LangChain
        self.llm = None
        self.embeddings = None
        self.current_model = ""
        self.current_embedding = ""
        self.db = None                   # vector DB
        self.memory = None               # LangChain chat memory
        self.conversation_chain = None   # ConversationChain
        self.query_generator = None      # query refiner that uses the chat history
        self.qa_chain = None             # load_qa_chain
        self.embedded_urls = []
        self.similarity_search_k = None  # no. of documents to fetch by similarity search
        self.summarization_mode = None   # stuff / map_reduce

        # App
        self.dialogue = []  # recent chat history for display

    # --------------------------------------
    # Empty cache
    # --------------------------------------
    def cache_clear(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear GPU memory

        gc.collect()  # clear CPU memory

    # --------------------------------------
    # Clear models (llm: LLM, embd: embeddings, db: vector DB)
    # --------------------------------------
    def clear_memory(self, llm=False, embd=False, db=False):
        # DB
        if db and self.db:
            self.db.delete_collection()
            self.db = None
            self.embedded_urls = []

        # Embeddings model
        if llm or embd:
            self.embeddings = None
            self.current_embedding = ""
            self.qa_chain = None

        # LLM model
        if llm:
            self.llm = None
            self.pipe = None
            self.model = None
            self.current_model = ""
            self.tokenizer = None
            self.memory = None
            self.chat_history = []  # TODO: verify whether this attribute is still needed

        self.cache_clear()

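# --------------------------------------
# Minimal sketch of the gr.State pattern used with this class (illustrative
# only; not wired into the app, and the helper name below is hypothetical).
# Gradio hands each browser session its own copy of the state object, so an
# event handler receives the state as an input and returns it as an output.
# --------------------------------------
def _sketch_session_roundtrip(ss: "SessionState", message: str):
    ss.dialogue.append((message, None))  # record the user's turn
    return ss, f"turns so far: {len(ss.dialogue)}"
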
# --------------------------------------
# Custom TextSplitter (splits text so chunks fit within the LLM token limit)
# (Reference) https://www.sato-susumu.com/entry/2023/04/30/131338
#  -> separators extended with "!", "?", ")", ".", "!", "?", "," etc.
# --------------------------------------
class JPTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "。", "!", "?", ")", "、", ".", "!", "?", ",", " ", ""]
        super().__init__(separators=separators, **kwargs)

# Chunking settings
chunk_size = 512
chunk_overlap = 35

text_splitter = JPTextSplitter(
    chunk_size=chunk_size,        # max characters per chunk
    chunk_overlap=chunk_overlap,  # max overlapping characters
)

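# --------------------------------------
# Minimal sketch of how the splitter behaves (illustrative only; not wired
# into the app, helper name hypothetical). Japanese sentence boundaries such
# as "。" are preferred split points before falling back to single characters.
# --------------------------------------
def _sketch_split():
    sample = "深層学習は表現学習の一種です。大量のデータから特徴を自動で抽出します。"
    for chunk in JPTextSplitter(chunk_size=20, chunk_overlap=0).split_text(sample):
        print(chunk)  # each chunk ends at or near a sentence boundary
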
# --------------------------------------
# Translate memory with DeepL to cut token usage (when using OpenAI models)
# --------------------------------------
DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")

def deepl_memory(ss: SessionState) -> SessionState:
    if ss.current_model == "gpt-3.5-turbo":
        # Fetch the latest exchange from memory
        user_message = ss.memory.chat_memory.messages[-2].content
        ai_message = ss.memory.chat_memory.messages[-1].content
        text = [user_message, ai_message]

        # DeepL request parameters
        params = {
            "auth_key": DEEPL_API_KEY,
            "text": text,
            "target_lang": "EN",
            "source_lang": "JA",
            "tag_handling": "xml",
            "ignore_tags": "x",
        }
        request = requests.post(DEEPL_API_ENDPOINT, data=params)
        request.raise_for_status()  # raise an exception if the response status is an error
        response = request.json()

        # Pull the translations out of the JSON
        user_message = response["translations"][0]["text"]
        ai_message = response["translations"][1]["text"]

        # Drop the last exchange from memory and re-add the translated one
        ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]
        ss.memory.chat_memory.add_user_message(user_message)
        ss.memory.chat_memory.add_ai_message(ai_message)

    return ss

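# --------------------------------------
# Standalone sketch of the DeepL call above (illustrative only; not wired
# into the app, helper name hypothetical). When "text" is a list, the JSON
# response contains one translation per input, in order.
# --------------------------------------
def _sketch_deepl_translate(texts: list) -> list:
    params = {"auth_key": DEEPL_API_KEY, "text": texts, "target_lang": "EN"}
    response = requests.post(DEEPL_API_ENDPOINT, data=params)
    response.raise_for_status()
    return [t["text"] for t in response.json()["translations"]]
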
# --------------------------------------
# Prepend DuckDuckGo web-search results to the input prompt
# --------------------------------------
# DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
# DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")

def web_search(query, current_model) -> str:
    search = DuckDuckGoSearchRun()
    web_result = search(query)

    if current_model == "gpt-3.5-turbo":
        # Translate both the query and the search result into English
        text = [query, web_result]
        params = {
            "auth_key": DEEPL_API_KEY,
            "text": text,
            "target_lang": "EN",
            "source_lang": "JA",
            "tag_handling": "xml",
            "ignore_tags": "x",
        }
        request = requests.post(DEEPL_API_ENDPOINT, data=params)
        request.raise_for_status()
        response = request.json()

        query = response["translations"][0]["text"]
        web_result = response["translations"][1]["text"]

    web_query = query + "\nUse the following information as a reference to answer the question above in Japanese.\n===\nReference: " + web_result + "\n==="

    return web_query

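# --------------------------------------
# Quick sketch of the search tool on its own (illustrative only; not wired
# into the app, helper name hypothetical). DuckDuckGoSearchRun returns a
# single plain string that concatenates the top result snippets.
# --------------------------------------
def _sketch_web_search(q: str) -> str:
    return DuckDuckGoSearchRun().run(q)
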
# --------------------------------------
# LangChain custom prompts
# llama tokenizer
# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/

# OpenAI tokenizer
# https://platform.openai.com/tokenizer
# --------------------------------------

# --------------------------------------
# Conversation Chain Template
# --------------------------------------

# Tokens: OpenAI 104 / Llama 105 <- in Japanese: OpenAI 191 / Llama 162
sys_chat_message = """
The following is a conversation between an AI concierge and a customer.
The AI understands what the customer wants to know from the conversation history and the latest question,
and gives many specific details in Japanese. If the AI does not know the answer to a question, it does not
make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

chat_common_format = """
===
Question: {query}

Conversation History:
{chat_history}

日本語の回答: """

chat_template_std = f"{sys_chat_message}{chat_common_format}"
chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"

# --------------------------------------
# QA Chain Template (Stuff)
# --------------------------------------
# Tokens: OpenAI 113 / Llama 111 <- in Japanese: OpenAI 256 / Llama 225
sys_qa_message = """
You are an AI concierge who carefully answers questions from customers based on references.
You understand what the customer wants to know from the Conversation History and Question,
and give a specific answer in Japanese using sentences extracted from the following references.
If you do not know the answer, do not make up an answer and reply,
"誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

qa_common_format = """
===
Question: {query}
References: {context}
Conversation History:
{chat_history}

日本語の回答: """

qa_template_std = f"{sys_qa_message}{qa_common_format}"
qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"

# --------------------------------------
# QA Chain Template (Map Reduce)
# --------------------------------------
# 1. Prompt for the chain that builds a standalone question from the chat history and the latest question
query_generator_message = """
Referring to the "Conversation History", reformat the user's "Additional Question"
to a specific question in Japanese by filling in the missing subject, verb, objects,
complements, and other necessary information to get a better search result.
""".replace("\n", "")

query_generator_common_format = """
===
[Conversation History]
{chat_history}

[Additional Question] {query}
明確な質問文: """

query_generator_template_std = f"{query_generator_message}{query_generator_common_format}"
query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]"


# 2. Prompt for the chain that summarizes each retrieved reference against the generated question
question_prompt_message = """
From the following references, extract key information relevant to the question
and summarize it in a natural English sentence with clear subject, verb, object,
and complement.
""".replace("\n", "")

question_prompt_common_format = """
===
[references] {context}
[Question] {query}
[Summary] """

question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}"
question_prompt_template_llama2 = f"<s>[INST] <<SYS>>{question_prompt_message}<</SYS>>{question_prompt_common_format}[/INST]"


# 3. Prompt for the chain that answers from the generated question and the vector-DB summaries
combine_prompt_message = """
You are an AI concierge who carefully answers questions from customers based on references.
Provide a specific answer in Japanese using sentences extracted from the following references.
If you do not know the answer, do not make up an answer and reply,
"誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

combine_prompt_common_format = """
===
Question:
{query}
===
Reference: {summaries}
===
日本語の回答: """

combine_prompt_template_std = f"{combine_prompt_message}{combine_prompt_common_format}"
combine_prompt_template_llama2 = f"<s>[INST] <<SYS>>{combine_prompt_message}<</SYS>>{combine_prompt_common_format}[/INST]"


# --------------------------------------
# Summary prompt for ConversationSummaryBufferMemory
# Source: https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49
# --------------------------------------
# Tokens: OpenAI 212 / Llama 214 <- in Japanese: OpenAI 397 / Llama 297
conversation_summary_template = """
Using the example as a guide, compose a summary in English that gives an overview of the conversation by summarizing the "current summary" and the "new conversation".
===
Example
[Current Summary] Customer asks AI what it thinks about Artificial Intelligence, AI says Artificial Intelligence is a good tool.

[New Conversation]
Human: なぜ人工知能が良いツールだと思いますか?
AI: 人工知能は「人間の可能性を最大限に引き出すことを助ける」からです。

[New Summary] Customer asks what you think about Artificial Intelligence, and AI responds that it is a good force that helps humans reach their full potential.
===
[Current Summary] {summary}

[New Conversation]
{new_lines}

[New Summary]
""".strip()

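# --------------------------------------
# Sketch of how one of the templates above is filled in (illustrative only;
# not wired into the app; helper name and sample strings hypothetical).
# PromptTemplate.format() substitutes the named placeholders and returns the
# final prompt string that is sent to the LLM.
# --------------------------------------
def _sketch_prompt_format() -> str:
    prompt = PromptTemplate(input_variables=["query", "chat_history"], template=chat_template_std)
    return prompt.format(query="営業時間を教えてください", chat_history="Human: こんにちは\nAI: こんにちは!")
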
# Load the models
def load_models(
    ss: SessionState,
    model_id: str,
    embedding_id: str,
    openai_api_key: str,
    load_in_8bit: bool,
    verbose: bool,
    temperature: float,
    similarity_search_k: int,
    summarization_mode: str,
    min_length: int,
    max_new_tokens: int,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    num_return_sequences: int,
) -> (SessionState, str):

    # --------------------------------------
    # Save the settings
    # --------------------------------------
    ss.similarity_search_k = similarity_search_k
    ss.summarization_mode = summarization_mode

    # --------------------------------------
    # Check the OpenAI API key
    # --------------------------------------
    if model_id == "gpt-3.5-turbo" or embedding_id == "text-embedding-ada-002":
        if not os.getenv("OPENAI_API_KEY"):
            status_message = "❌ OpenAI API KEY を設定してください"
            return ss, status_message

    # --------------------------------------
    # Configure the LLM
    # --------------------------------------
    # OpenAI model
    if model_id == "gpt-3.5-turbo":
        ss.clear_memory(llm=True, db=True)
        ss.llm = ChatOpenAI(
            model_name=model_id,
            temperature=temperature,
            verbose=verbose,
            max_tokens=max_new_tokens,
        )

    # Hugging Face model
    else:
        ss.clear_memory(llm=True, db=True)

        if model_id == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        else:
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id)

        ss.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        ss.pipe = pipeline(
            "text-generation",
            model=ss.model,
            tokenizer=ss.tokenizer,
            min_length=min_length,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_return_sequences,
            temperature=temperature,
        )
        ss.llm = HuggingFacePipeline(pipeline=ss.pipe)

    # --------------------------------------
    # Configure the embedding model
    # --------------------------------------
    if ss.current_embedding == embedding_id:
        pass

    else:
        # Reset embeddings and vector DB
        ss.clear_memory(embd=True, db=True)

        if embedding_id == "None":
            pass

        # OpenAI
        elif embedding_id == "text-embedding-ada-002":
            ss.embeddings = OpenAIEmbeddings()

        # Hugging Face
        else:
            ss.embeddings = HuggingFaceEmbeddings(model_name=embedding_id)

    # --------------------------------------
    # Record the active model names on the SessionState object
    # (done before building the chains, since set_chains branches on them)
    # --------------------------------------
    ss.current_model = model_id
    ss.current_embedding = embedding_id

    # --------------------------------------
    # Configure the chains
    # --------------------------------------
    ss = set_chains(ss, summarization_mode)

    # Status message
    status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding

    return ss, status_message

# --------------------------------------
# Unified setup for the Conversation/QA chains
# --------------------------------------
def set_chains(ss: SessionState, summarization_mode) -> SessionState:

    # Pick the chat templates to match the model
    human_prefix = "Human: "
    ai_prefix = "AI: "
    chat_template = chat_template_std
    qa_template = qa_template_std
    query_generator_template = query_generator_template_std
    question_template = question_prompt_template_std
    combine_template = combine_prompt_template_std

    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        # Rinna-specific settings: newline code and memory prefixes (see the official model page)
        chat_template = chat_template.replace("\n", "<NL>")
        qa_template = qa_template.replace("\n", "<NL>")
        query_generator_template = query_generator_template_std.replace("\n", "<NL>")
        question_template = question_prompt_template_std.replace("\n", "<NL>")
        combine_template = combine_prompt_template_std.replace("\n", "<NL>")
        human_prefix = "ユーザー: "
        ai_prefix = "システム: "

    elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
        # ELYZA-specific templates
        chat_template = chat_template_llama2
        qa_template = qa_template_llama2
        query_generator_template = query_generator_template_llama2
        question_template = question_prompt_template_llama2
        combine_template = combine_prompt_template_llama2

    # --------------------------------------
    # Memory setup
    # --------------------------------------
    if ss.memory is None:
        conversation_summary_prompt = PromptTemplate(input_variables=['summary', 'new_lines'], template=conversation_summary_template)
        ss.memory = ConversationSummaryBufferMemory(
            llm=ss.llm,
            memory_key="chat_history",
            input_key="query",
            output_key="output_text",
            return_messages=False,
            human_prefix=human_prefix,
            ai_prefix=ai_prefix,
            max_token_limit=1024,
            prompt=conversation_summary_prompt,
        )

    # --------------------------------------
    # Conversation/QA chain setup
    # --------------------------------------
    if ss.conversation_chain is None:
        chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
        ss.conversation_chain = ConversationChain(
            llm=ss.llm,
            prompt=chat_prompt,
            memory=ss.memory,
            input_key="query",
            output_key="output_text",
            verbose=True,
        )

    if ss.qa_chain is None:
        if summarization_mode == "stuff":
            qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)
            ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)

        elif summarization_mode == "map_reduce":
            query_generator_prompt = PromptTemplate(template=query_generator_template, input_variables=["chat_history", "query"])
            ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt)

            question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
            combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"])
            ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt)

    return ss

def initialize_db(ss: SessionState) -> SessionState:

    # client = chromadb.PersistentClient(path="./db")
    ss.db = Chroma(
        collection_name="user_reference",
        embedding_function=ss.embeddings,
        # client=client,
    )

    return ss

def embedding_process(ss: SessionState, ref_documents: list) -> SessionState:

    # --------------------------------------
    # Restructure the text and strip unwanted characters
    # --------------------------------------
    for i in range(len(ref_documents)):
        content = ref_documents[i].page_content.strip()

        # --------------------------------------
        # For PDFs, apply heavier cleanup to work around extraction errors
        # --------------------------------------
        if ".pdf" in ref_documents[i].metadata['source']:
            pdf_replacement_sets = [
                ('\n ', '**PLACEHOLDER+SPACE**'),
                ('\n\u3000', '**PLACEHOLDER+SPACE**'),
                ('.\n', '。**PLACEHOLDER**'),
                (',\n', '。**PLACEHOLDER**'),
                ('?\n', '。**PLACEHOLDER**'),
                ('!\n', '。**PLACEHOLDER**'),
                ('!\n', '。**PLACEHOLDER**'),
                ('。\n', '。**PLACEHOLDER**'),
                ('!\n', '!**PLACEHOLDER**'),
                (')\n', '!**PLACEHOLDER**'),
                (']\n', '!**PLACEHOLDER**'),
                ('?\n', '?**PLACEHOLDER**'),
                (')\n', '?**PLACEHOLDER**'),
                ('】\n', '?**PLACEHOLDER**'),
            ]
            for original, replacement in pdf_replacement_sets:
                content = content.replace(original, replacement)
            content = content.replace(" ", "")
        # --------------------------------------

        # Remove unwanted strings and whitespace
        remove_texts = ["\n", "\r", " "]
        for remove_text in remove_texts:
            content = content.replace(remove_text, "")

        # Convert tabs and full-width spaces to single spaces
        replace_texts = ["\t", "\u3000"]
        for replace_text in replace_texts:
            content = content.replace(replace_text, " ")

        # Restore the legitimate PDF line breaks
        if ".pdf" in ref_documents[i].metadata['source']:
            content = content.replace('**PLACEHOLDER**', '\n').replace('**PLACEHOLDER+SPACE**', '\n ')

        ref_documents[i].page_content = content

    # --------------------------------------
    # Split into chunks
    texts = text_splitter.split_documents(ref_documents)

    # --------------------------------------
    # Add the prefix expected by the multilingual-e5 training setup
    # https://hironsan.hatenablog.com/entry/2023/07/05/073150
    # --------------------------------------
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        for i in range(len(texts)):
            texts[i].page_content = "passage: " + texts[i].page_content

    # Initialize the vector DB
    if ss.db is None:
        ss = initialize_db(ss)

    # Embed into the DB
    # ss.db = Chroma.from_documents(texts, ss.embeddings)
    ss.db.add_documents(documents=texts, embedding=ss.embeddings)

    return ss

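# --------------------------------------
# Sketch of querying the vector DB populated above (illustrative only; not
# wired into the app, helper name hypothetical). similarity_search() embeds
# the query text and returns the k closest Document chunks, whose metadata
# still carries the source name set at registration time.
# --------------------------------------
def _sketch_similarity_search(ss: SessionState, question: str):
    for doc in ss.db.similarity_search(question, k=3):
        print(doc.metadata.get("source"), doc.page_content[:40])
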
def embed_ref(ss: SessionState, urls: str, fileobj: list, header_lim: int, footer_lim: int) -> (SessionState, str):

    # --------------------------------------
    # Confirm the models are loaded
    # --------------------------------------
    if ss.llm is None or ss.embeddings is None:
        status_message = "❌ LLM/Embeddingモデルが登録されていません。"
        return ss, status_message

    url_flag = "-"
    pdf_flag = "-"

    # --------------------------------------
    # Load URLs and register them in the vector DB
    # --------------------------------------

    # Preprocess the URL list (split into a list, dedupe, drop non-URLs)
    urls = list({url for url in urls.split("\n") if url and "://" in url})

    if urls:
        # Drop URLs that are already embedded (ss.embedded_urls), then record the new ones
        urls = [url for url in urls if url not in ss.embedded_urls]
        ss.embedded_urls.extend(urls)

        # Load the web pages
        loader = SeleniumURLLoader(urls=urls)
        ref_documents = loader.load()

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        url_flag = "✅ 登録済"

    # --------------------------------------
    # Strip PDF headers/footers and register the body in the vector DB
    # https://pypdf.readthedocs.io/en/stable/user/extract-text.html
    # --------------------------------------

    if fileobj is None:
        pass

    else:
        # Collect the file names
        pdf_paths = []
        for path in fileobj:
            pdf_paths.append(path.name)

        # Initialize the document list
        ref_documents = []

        # Read each PDF file
        for pdf_path in pdf_paths:
            pdf = PdfReader(pdf_path)
            body = []

            def visitor_body(text, cm, tm, font_dict, font_size):
                y = tm[5]
                if y > footer_lim and y < header_lim:  # keep only text whose y position lies between footer and header
                    parts.append(text)

            for page in pdf.pages:
                parts = []
                page.extract_text(visitor_text=visitor_body)
                body.append("".join(parts))

            body = "\n".join(body)

            # Keep only the file name from the path
            filename = os.path.basename(pdf_path)
            # Convert the extracted text into a LangChain Document
            ref_documents.append(Document(page_content=body, metadata={"source": filename}))

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        pdf_flag = "✅ 登録済"


    langchain.debug = True

    status_message = "URL: " + url_flag + " / PDF: " + pdf_flag
    return ss, status_message

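# --------------------------------------
# Standalone sketch of the pypdf visitor pattern above (illustrative only;
# not wired into the app, helper name hypothetical). The text matrix tm
# holds the draw position: tm[4] is x and tm[5] is y in points from the
# bottom-left corner of the page, which is why the header/footer limits are
# given in points.
# --------------------------------------
def _sketch_pdf_body(pdf_path: str, header_lim: int = 792, footer_lim: int = 40) -> str:
    parts = []

    def keep_body(text, cm, tm, font_dict, font_size):
        if footer_lim < tm[5] < header_lim:  # drop header/footer lines
            parts.append(text)

    for page in PdfReader(pdf_path).pages:
        page.extract_text(visitor_text=keep_body)
    return "".join(parts)
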
def clear_db(ss: SessionState) -> (SessionState, str):
    if ss.db is None:
        status_message = "❌ 参照データが登録されていません。"
        return ss, status_message

    try:
        ss.db.delete_collection()
        status_message = "✅ 参照データを削除しました。"

    except NameError:
        status_message = "❌ 参照データが登録されていません。"

    return ss, status_message

# ----------------------------------------------------------------------------
# query input ▶ [def user] ▶ [ def bot ] ▶ [def show_response] ▶ chat screen
#                   ⬇            ⬇                  ⬆
#              chat screen   [qa_predict / chat_predict]
# ----------------------------------------------------------------------------

def user(ss: SessionState, query) -> (SessionState, list):
    # Once the dialogue history exceeds a fixed length, drop the oldest entry
    if len(ss.dialogue) > 20:
        ss.dialogue.pop(0)

    ss.dialogue = ss.dialogue + [(query, None)]  # dialogue history (None is the empty slot for the bot's reply)
    chat_history = ss.dialogue

    # The chat screen displays chat_history
    return ss, chat_history

def bot(ss: SessionState, query, qa_flag, web_flag, summarization_mode) -> (SessionState, str):

    original_query = query

    if ss.llm is None:
        response = "LLMが設定されていません。設定画面で任意のモデルを選択してください。"
        ss.dialogue[-1] = (ss.dialogue[-1][0], response)
        return ss, ""

    elif qa_flag is True and ss.embeddings is None:
        response = "Embeddingモデルが設定されていません。設定画面で任意のモデルを選択してください。"
        ss.dialogue[-1] = (ss.dialogue[-1][0], response)
        return ss, ""

    elif qa_flag is True and ss.db is None:
        response = "参照データが登録されていません。"
        ss.dialogue[-1] = (ss.dialogue[-1][0], response)
        return ss, ""

    # Refine the query
    history = ss.memory.load_memory_variables({})
    if history["chat_history"] != "" and ss.query_generator is not None:
        # Refine the query using the chat history
        query = ss.query_generator({"query": query, "chat_history": history["chat_history"]})["text"]

    # QA model
    if qa_flag is True and ss.embeddings is not None and ss.db is not None:
        if web_flag:
            web_query = web_search(query, ss.current_model)
            ss = qa_predict(ss, web_query)
            ss.memory.chat_memory.messages[-2].content = query
        else:
            ss = qa_predict(ss, query)  # generate the answer with the LLM

    # Chat model
    else:
        if web_flag:
            web_query = web_search(query, ss.current_model)
            ss = chat_predict(ss, web_query)
            ss.memory.chat_memory.messages[-2].content = query
        else:
            ss = chat_predict(ss, query)

    # When a GPT model is in use, translate the memory into English via DeepL
    ss = deepl_memory(ss)

    return ss, ""  # ss and the (now empty) query box

def chat_predict(ss: SessionState, query) -> SessionState:
    response = ss.conversation_chain.predict(query=query)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)
    return ss

def qa_predict(ss: SessionState, query) -> SessionState:

    # Rinna-specific setting (fix the query's newline code)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        query = query.strip().replace("\n", "<NL>")
    else:
        query = query.strip()

    # Query prefix for multilingual-e5
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        db_query_str = "query: " + query
    else:
        db_query_str = query

    # Fetch the related chunks and their sources from the DB
    docs = ss.db.similarity_search(db_query_str, k=ss.similarity_search_k)
    sources = "\n\n[Sources]\n" + '\n - '.join(list(set(doc.metadata['source'] for doc in docs if 'source' in doc.metadata)))

    # Rinna-specific setting (fix the newline code of the retrieved chunks)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        for i in range(len(docs)):
            docs[i].page_content = docs[i].page_content.strip().replace("\n", "<NL>")

    # Generate the answer (up to three attempts)
    for _ in range(3):
        result = ss.qa_chain({"input_documents": docs, "query": query})
        result["output_text"] = result["output_text"].replace("<NL>", "\n").strip("...").strip("回答:").strip()

        # If result["output_text"] is not empty, update memory and return
        if result["output_text"] != "":
            response = result["output_text"] + sources
            ss.dialogue[-1] = (ss.dialogue[-1][0], response)
            return ss
        else:
            # If empty, drop the latest exchange from memory and retry
            ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]

    # Still empty after three attempts
    response = "3回試行しましたが、回答を生成できませんでした。"
    if sources != "":
        response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"

    # Add the user and AI messages to memory
    ss.memory.chat_memory.add_user_message(query.replace("<NL>", "\n"))
    ss.memory.chat_memory.add_ai_message(response)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)  # dialogue history
    return ss

# Stream the answer to the chat screen one character at a time
def show_response(ss: SessionState) -> str:
    chat_history = [list(item) for item in ss.dialogue]  # convert tuples to lists to get a mutable copy of the dialogue
    response = chat_history[-1][1]  # stash the bot's reply [1] from the latest exchange [-1]
    chat_history[-1][1] = ""        # empty the reply slot so it can be rebuilt incrementally

    if response is None:
        response = "回答を生成できませんでした。"

    for character in response:
        chat_history[-1][1] += character
        time.sleep(0.05)
        yield chat_history

with gr.Blocks() as demo:

    # Instantiate the per-user session memory (reset on page reload)
    ss = gr.State(SessionState())

    # --------------------------------------
    # Functions to set/clear the API key
    # --------------------------------------
    def openai_api_setfn(openai_api_key) -> str:
        if openai_api_key == "kikagaku":
            os.environ["OPENAI_API_KEY"] = os.getenv("kikagaku_demo")
            status_message = "✅ キカガク専用DEMOへようこそ!APIキーを設定しました"
            return status_message
        elif not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
            os.environ["OPENAI_API_KEY"] = ""
            status_message = "❌ 有効なAPIキーを入力してください"
            return status_message
        else:
            os.environ["OPENAI_API_KEY"] = openai_api_key
            status_message = "✅ APIキーを設定しました"
            return status_message

    def openai_api_clsfn(ss) -> (str, str):
        openai_api_key = ""
        os.environ["OPENAI_API_KEY"] = ""
        status_message = "✅ APIキーの削除が完了しました"
        return status_message, ""

    with gr.Tabs():
        # --------------------------------------
        # Setting Tab
        # --------------------------------------
        with gr.TabItem("1. LLM設定"):
            with gr.Row():
                model_id = gr.Dropdown(
                    choices=[
                        'elyza/ELYZA-japanese-Llama-2-7b-fast-instruct',
                        'rinna/bilingual-gpt-neox-4b-instruction-sft',
                        'gpt-3.5-turbo',
                    ],
                    value="gpt-3.5-turbo",
                    label='LLM model',
                    interactive=True,
                )
            with gr.Row():
                embedding_id = gr.Dropdown(
                    choices=[
                        'intfloat/multilingual-e5-large',
                        'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
                        'oshizo/sbert-jsnli-luke-japanese-base-lite',
                        'text-embedding-ada-002',
                        "None"
                    ],
                    value="text-embedding-ada-002",
                    label='Embedding model',
                    interactive=True,
                )
            with gr.Row():
                with gr.Column(scale=19):
                    openai_api_key = gr.Textbox(label="OpenAI API Key (Optional)", interactive=True, type="password", value="", placeholder="Your OpenAI API Key for OpenAI models.", max_lines=1)
                with gr.Column(scale=1):
                    openai_api_set = gr.Button(value="Set API KEY", size="sm")
                    openai_api_cls = gr.Button(value="Delete API KEY", size="sm")

            # with gr.Row():
            #     reference_libs = gr.CheckboxGroup(choices=['LangChain', 'Gradio'], label="Reference Libraries", interactive=False)

            # Advanced settings (collapsed)
            with gr.Accordion(label="Advanced Setting", open=False):
                with gr.Row():
                    with gr.Column():
                        load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
                        verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=True)
                    with gr.Column():
                        temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
                    with gr.Column():
                        similarity_search_k = gr.Slider(label="similarity_search_k (OpenAI, HF)", minimum=1, maximum=10, step=1, value=3, interactive=True)
                    with gr.Column():
                        summarization_mode = gr.Radio(choices=['stuff', 'map_reduce'], label="Summarization mode", value='stuff', interactive=True)
                    with gr.Column():
                        min_length = gr.Slider(label="min_length (HF)", minimum=1, maximum=100, step=1, value=10, interactive=True)
                    with gr.Column():
                        max_new_tokens = gr.Slider(label="max_tokens(OpenAI), max_new_tokens(HF)", minimum=1, maximum=1024, step=1, value=256, interactive=True)
                    with gr.Column():
                        top_k = gr.Slider(label='top_k (HF)', minimum=1, maximum=100, step=1, value=40, interactive=True)
                    with gr.Column():
                        top_p = gr.Slider(label='top_p (HF)', minimum=0.01, maximum=0.99, step=0.01, value=0.92, interactive=True)
                    with gr.Column():
                        repetition_penalty = gr.Slider(label='repetition_penalty (HF)', minimum=0.5, maximum=2, step=0.1, value=1.2, interactive=True)
                    with gr.Column():
                        num_return_sequences = gr.Slider(label='num_return_sequences (HF)', minimum=1, maximum=20, step=1, value=3, interactive=True)

            with gr.Row():
                with gr.Column(scale=2):
                    config_btn = gr.Button(value="Configure")
                with gr.Column(scale=13):
                    status_cfg = gr.Textbox(show_label=False, interactive=False, value="モデルを設定してください", container=False, max_lines=1)

            # Wire up the buttons and inputs
            openai_api_set.click(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            openai_api_cls.click(openai_api_clsfn, inputs=[openai_api_key], outputs=[status_cfg, openai_api_key], show_progress="full")
            openai_api_key.submit(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            config_btn.click(
                fn=load_models,
                inputs=[ss, model_id, embedding_id, openai_api_key, load_in_8bit, verbose, temperature,
                        similarity_search_k, summarization_mode,
                        min_length, max_new_tokens, top_k, top_p, repetition_penalty, num_return_sequences],
                outputs=[ss, status_cfg],
                queue=True,
                show_progress="full",
            )

        # --------------------------------------
        # Reference Tab
        # --------------------------------------
        with gr.TabItem("2. References"):
            urls = gr.TextArea(
                max_lines=60,
                show_label=False,
                info="List any reference URLs for Q&A retrieval.",
                placeholder="https://blog.kikagaku.co.jp/deep-learning-transformer\nhttps://note.com/elyza/n/na405acaca130",
                interactive=True,
            )

            with gr.Row():
                pdf_paths = gr.File(label="PDFs", height=150, min_width=60, scale=7, file_types=[".pdf"], file_count="multiple", interactive=True)
                header_lim = gr.Number(label="Header (pt)", step=1, value=792, precision=0, min_width=70, scale=1, interactive=True)
                footer_lim = gr.Number(label="Footer (pt)", step=1, value=0, precision=0, min_width=70, scale=1, interactive=True)
                pdf_ref = gr.Textbox(show_label=False, value="A4 Size:\n(下)0-792pt(上)\n *28.35pt/cm", container=False, scale=1, interactive=False)

            with gr.Row():
                ref_set_btn = gr.Button(value="コンテンツ登録", scale=1)
                ref_clear_btn = gr.Button(value="登録データ削除", scale=1)
                status_ref = gr.Textbox(show_label=False, interactive=False, value="参照データ未登録", container=False, max_lines=1, scale=18)

            ref_set_btn.click(fn=embed_ref, inputs=[ss, urls, pdf_paths, header_lim, footer_lim], outputs=[ss, status_ref], queue=True, show_progress="full")
            ref_clear_btn.click(fn=clear_db, inputs=[ss], outputs=[ss, status_ref], show_progress="full")

        # --------------------------------------
        # Chatbot Tab
        # --------------------------------------
        with gr.TabItem("3. Q&A Chat"):
            chat_history = gr.Chatbot([], elem_id="chatbot", avatar_images=["bear.png", "penguin.png"])
            with gr.Row():
                with gr.Column(scale=95):
                    query = gr.Textbox(
                        show_label=False,
                        placeholder="Send a message with [Shift]+[Enter] key.",
                        lines=4,
                        container=False,
                        autofocus=True,
                        interactive=True,
                    )
                with gr.Column(scale=5):
                    with gr.Row():
                        qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=True)
                        web_flag = gr.Checkbox(label="Web Search", value=False, min_width=60, interactive=True)
                    with gr.Row():
                        query_send_btn = gr.Button(value="▶")

            # gr.Examples(["機械学習について説明してください"], inputs=[query])
            query.submit(
                user, [ss, query], [ss, chat_history]
            ).then(
                bot, [ss, query, qa_flag, web_flag, summarization_mode], [ss, query]
            ).then(
                show_response, [ss], [chat_history]
            )

            query_send_btn.click(
                user, [ss, query], [ss, chat_history]
            ).then(
                bot, [ss, query, qa_flag, web_flag, summarization_mode], [ss, query]
            ).then(
                show_response, [ss], [chat_history]
            )

if __name__ == "__main__":
    demo.queue(concurrency_count=5)
    demo.launch(debug=True)
bear.png
ADDED
packages.txt
ADDED
@@ -0,0 +1 @@
chromium-driver
penguin.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,23 @@
accelerate==0.22.0
beautifulsoup4==4.11.2
bitsandbytes==0.41.1
transformers==4.30.0
sentence-transformers==2.2.2
sentencepiece==0.1.99
langchain==0.0.281
xformers==0.0.21
chromadb==0.4.8
gradio==3.42.0
gradio_client==0.5.0
openai==0.28.0
tiktoken==0.4.0
fugashi==1.3.0
ipadic==1.0.0
unstructured==0.10.12
selenium==4.12.0
pypdf==3.15.5
Cython==0.29.36
numpy==1.23.5
pandas==1.5.3
chromedriver-autoinstaller
chromedriver-binary