import streamlit as st
from dotenv import load_dotenv
import os
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFaceHub
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

# Load a local .env file (if present) so the HuggingFace token can also be
# supplied via the environment.
load_dotenv()

st.title('LLM - Retrieval Augmented Generation')

model_names = ['tiiuae/falcon-7b-instruct',
               'google/gemma-2-2b',
               'mistralai/Mistral-7B-v0.1']

api_urls = ['https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct',
            'https://api-inference.huggingface.co/models/google/gemma-2-2b',
            'https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1']

model_dict = dict(zip(model_names, api_urls))
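# model_dict maps each model name to its Inference API URL; it is used further
# down to point the HuggingFaceHub client at the selected model's endpoint.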

pdf = st.file_uploader(label='Upload PDF')

# Instruction delimiters wrapped around the prompt template (Mistral/Llama
# instruct-style formatting)
start_metatag = "<s>[INST]"
default_template = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, say that you don't know.
Never stop generating mid-sentence.

Question: {question}
Context: {context}
Answer:
"""
end_metatag = "[/INST]"

with st.sidebar:

    st.write('# Retrieval parameters')
    chunk_size = st.number_input(label='Chunk size', value=250, step=10)
    chunk_overlap = st.number_input(label='Chunk overlap', value=50, step=10)

    st.write('# Prompt')
    rag_template = st.text_area(label='Prompt template', value=default_template, height=250)

    st.write('# LLM parameters')
    model = st.selectbox(label='Model', options=model_names, index=0)
    temperature = st.slider(label='Model Temperature', min_value=0.1, max_value=float(10), value=float(1), step=0.1)


template = start_metatag + '\n\n' + rag_template + '\n\n' + end_metatag
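# With the default template, the assembled prompt has the shape:
#
#   <s>[INST]
#
#   You are an assistant for question-answering tasks.
#   ...
#   Question: {question}
#   Context: {context}
#   Answer:
#
#   [/INST]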

question = st.text_input(label='Question')


def authenticate():
    # Report whether the token in the environment matches the one in Streamlit's
    # secrets; HuggingFaceHub reads it from the HUGGINGFACEHUB_API_TOKEN variable.
    try:
        st.write('Authenticated with HuggingFace:',
                 os.environ["HUGGINGFACEHUB_API_TOKEN"] == st.secrets["HUGGINGFACEHUB_API_TOKEN"])
    except (KeyError, FileNotFoundError):
        st.write('Cannot find HuggingFace API token. Ensure it is located in .streamlit/secrets.toml')
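
# The token lives in .streamlit/secrets.toml; a minimal example file
# (the value below is a placeholder, not a real token):
#
#   HUGGINGFACEHUB_API_TOKEN = "hf_xxxxxxxxxxxxxxxx"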


def load_pdf(pdf):
    # Read the uploaded PDF and concatenate the text of all its pages
    reader = PdfReader(pdf)

    text = ""
    for page in reader.pages:
        page_text = page.extract_text()
        if page_text:  # skip pages with no extractable text
            text += page_text

    return text


def split_text(text, chunk_size=400, chunk_overlap=20):
    # Split the raw text into overlapping chunks suitable for embedding
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""]
    )

    chunks = text_splitter.split_text(text)

    return chunks


def store_text(chunks):
    # Embed each chunk with an Instructor model and index the vectors in FAISS
    embeddings = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-large')

    vectorstore = FAISS.from_texts(texts=chunks, embedding=embeddings)

    return vectorstore


@st.cache_resource
def load_split_store(pdf, chunk_size, chunk_overlap):
    # End-to-end ingestion, cached by Streamlit so the PDF is only re-processed
    # when the file or the chunking parameters change
    text = load_pdf(pdf=pdf)
    chunks = split_text(text, chunk_size, chunk_overlap)
    vectorstore = store_text(chunks)

    return vectorstore


def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


@st.cache_resource
def instantiate_llm(model, temperature):
    # Connect to the selected model on the HuggingFace Inference API
    llm = HuggingFaceHub(
        repo_id=model,
        model_kwargs={
            'temperature': temperature,
        }
    )

    # Point the underlying client at the endpoint for the chosen model
    llm.client.api_url = model_dict[str(model)]

    return llm


@st.cache_resource
def instantiate_prompt(template=template):
    # A PromptTemplate only needs the template string and its input variables;
    # it is not bound to any particular LLM.
    prompt = PromptTemplate(
        template=template,
        input_variables=['question', 'context']
    )
    return prompt
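
# prompt.format(question='...', context='...') would return the filled-in
# template; in the chains below this formatting happens implicitly when the
# retrieved context and the user question are piped into the prompt.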


def main():

    authenticate()

    llm = instantiate_llm(model, temperature)

    prompt = instantiate_prompt(template)

    if pdf is not None:

        vectorstore = load_split_store(pdf, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        st.write('PDF processed')

        retriever = vectorstore.as_retriever()

        # Chain used only to display the retrieved context
        retrieval_chain = (
            retriever | format_docs
        )

        # Full RAG chain: retrieve context, fill the prompt, call the LLM, parse the output
        generation_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )

        if st.button(label='Ask question'):
            with st.spinner('Processing'):

                st.write('# Context')
                st.write(retrieval_chain.invoke(question))

                st.write('# Answer')
                st.write(generation_chain.invoke(question))


if __name__ == '__main__':
    main()