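"""Gradio Space for paper-qa: ask questions of uploaded text/PDF documents
or a scraped GitHub repo, answered with OpenAI GPT models."""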
import gradio as gr
import paperqa
import requests
import zipfile
import io
import tempfile
import os

# Global cache of the paperqa.Docs index; cleared whenever the dataset changes.
docs = None
def request_pathname(files):
    # Build dataframe rows [filepath, citation string, key] from uploaded files,
    # plus an initial stats row; both outputs are expected by the event wiring below.
    if files is None:
        return [[]], [[0, 0]]
    return [[file.name, file.name.split('/')[-1], None] for file in files], [[len(files), 0]]
def validate_dataset(dataset, openapi):
    global docs
    docs = None  # clear it out if dataset is modified
    docs_ready = dataset.iloc[-1, 0] != ""
    key_ready = isinstance(openapi, str) and len(openapi) > 0
    if docs_ready and key_ready:
        return "✨Ready✨"
    elif docs_ready:
        return "⚠️Waiting for key..."
    elif key_ready:
        return "⚠️Waiting for documents..."
    else:
        return "⚠️Waiting for documents and key..."
def make_stats(docs):
    # [number of documents, total number of chunks across them]
    return [[len(docs.doc_previews), sum(x[0] for x in docs.doc_previews)]]
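
# do_ask is a generator: each `yield` streams a partial update to the five
# outputs wired to the Ask button below (answer, context, passages, the
# cached Docs state, and the stats table).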
# , progress=gr.Progress()):
def do_ask(question, button, openapi, dataset, length, do_marg, k, max_sources, docs):
    passages = ""
    docs_ready = dataset.iloc[-1, 0] != ""
    if button == "✨Ready✨" and isinstance(openapi, str) and len(openapi) > 0 and docs_ready:
        if docs is None:  # don't want to rebuild the index if it's already built
            os.environ['OPENAI_API_KEY'] = openapi.strip()
            docs = paperqa.Docs()
            # dataset is a pandas dataframe
            for _, row in dataset.iterrows():
                try:
                    docs.add(row['filepath'], row['citation string'],
                             key=row['key'], disable_check=True)
                    yield "", "", "", docs, make_stats(docs)
                except Exception:
                    # skip files that paperqa cannot parse
                    pass
        else:
            yield "", "", "", docs, [[0, 0]]
        # progress(0, "Building Index...")
        docs._build_faiss_index()
        # progress(0.25, "Querying...")
        for i, result in enumerate(docs.query_gen(question,
                                                  length_prompt=f'use {length:d} words',
                                                  marginal_relevance=do_marg,
                                                  k=k, max_sources=max_sources)):
            # progress(0.25 + 0.1 * i, "Generating Context" + str(i))
            yield result.formatted_answer, result.context, passages, docs, make_stats(docs)
        # progress(1.0, "Done!")
        # format the passages
        for i, (key, passage) in enumerate(result.passages.items()):
            passages += f'{i+1}. {key}\n\n >{passage} \n\n'
        yield result.formatted_answer, result.context, passages, docs, make_stats(docs)
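
# Note: the zipball endpoint below is the unauthenticated GitHub API, which is
# rate-limited (60 requests/hour per IP at the time of writing), so repeated
# downloads may come back with a non-200 status.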
def download_repo(gh_repo, pbar=gr.Progress()):
    # download zipped version of repo
    r = requests.get(f'https://api.github.com/repos/{gh_repo}/zipball')
    files = []
    if r.status_code == 200:
        pbar(0.01, 'Downloaded')
        # iterate through files in zip
        with zipfile.ZipFile(io.BytesIO(r.content)) as z:
            for i, f in enumerate(z.namelist()):
                # skip directories
                if f.endswith('/'):
                    continue
                # try to read as plaintext (skip binary files)
                try:
                    text = z.read(f).decode('utf-8')
                except UnicodeDecodeError:
                    continue
                # skip files bigger than 1 MB or smaller than 10 bytes
                if len(text) > 1e6 or len(text) < 10:
                    continue
                # have to save to a temporary file so we have a path for paperqa
                with tempfile.NamedTemporaryFile(delete=False) as tmp:
                    tmp.write(text.encode('utf-8'))
                    tmp.flush()
                    path = tmp.name
                # strip off the zipball's top-level directory from f
                rel_path = '/'.join(f.split('/')[1:])
                key = os.path.basename(f)
                citation = f'[{rel_path}](https://github.com/{gh_repo}/tree/main/{rel_path})'
                files.append([path, citation, key])
                yield files, [[len(files), 0]]
                pbar((i + 1) / len(z.namelist()), f'Added {f}')
        pbar(1.0, 'Done')
    else:
        raise ValueError('Unknown GitHub repo')
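
# UI layout: the API key and document inputs update the status box via
# validate_dataset; the Ask button streams do_ask's generator output into
# the answer, context, and passages panes.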
with gr.Blocks() as demo:
    docs = gr.State(None)

    gr.Markdown(f"""
    # Document Question and Answer (v{paperqa.__version__})

    *By Andrew White ([@andrewwhite01](https://twitter.com/andrewwhite01))*

    This tool lets you ask questions of your uploaded text or PDF documents,
    or of a scraped GitHub repo.
    It uses OpenAI's GPT models, so you must enter your API key below. This
    tool is under active development and currently uses many tokens - up to 10,000
    for a single query. That is $0.10-0.20 per query, so please be careful!

    * [PaperQA](https://github.com/whitead/paper-qa) is the code used to build this tool.
    * [langchain](https://github.com/hwchase17/langchain) is the main library this tool utilizes.

    1. Enter API Key ([What is that?](https://platform.openai.com/account/api-keys))
    2. Upload your documents and modify citation strings if you want (to look prettier in answers)
    """)
    openai_api_key = gr.Textbox(
        label="OpenAI API Key", placeholder="sk-...", type="password")
    with gr.Tab('File Upload'):
        uploaded_files = gr.File(
            label="Your Documents Upload (PDF or txt)", file_count="multiple")
    with gr.Tab('GitHub Repo'):
        gh_repo = gr.Textbox(
            label="GitHub Repo", placeholder="whitead/paper-qa")
        download = gr.Button("Download Repo")
    with gr.Accordion("See Docs:", open=False):
        dataset = gr.Dataframe(
            headers=["filepath", "citation string", "key"],
            datatype=["str", "str", "str"],
            col_count=(3, "fixed"),
            interactive=True,
            label="Documents and Citations",
            overflow_row_behaviour='paginate',
            max_rows=5
        )
    buildb = gr.Textbox("⚠️Waiting for documents and key...",
                        label="Status", interactive=False, show_label=True,
                        max_lines=1)
    stats = gr.Dataframe(headers=['Docs', 'Chunks'],
                         datatype=['number', 'number'],
                         col_count=(2, "fixed"),
                         interactive=False,
                         label="Doc Stats")

    # re-validate whenever the key or the document table changes
    openai_api_key.change(validate_dataset, inputs=[
                          dataset, openai_api_key], outputs=[buildb])
    dataset.change(validate_dataset, inputs=[
                   dataset, openai_api_key], outputs=[buildb])
    uploaded_files.change(request_pathname, inputs=[
                          uploaded_files], outputs=[dataset, stats])
    download.click(fn=download_repo, inputs=[
                   gh_repo], outputs=[dataset, stats])
    query = gr.Textbox(
        placeholder="Enter your question here...", label="Question")
    with gr.Row():
        length = gr.Slider(25, 200, value=100, step=5,
                           label='Words in answer')
        marg = gr.Checkbox(True, label='Max marginal relevance')
        k = gr.Slider(1, 20, value=10, step=1,
                      label='Chunks to examine')
        sources = gr.Slider(1, 10, value=5, step=1,
                            label='Contexts to include')
    ask = gr.Button("Ask Question")
    answer = gr.Markdown(label="Answer")
    with gr.Accordion("Context", open=True):
        context = gr.Markdown(label="Context")
    with gr.Accordion("Raw Text", open=False):
        passages = gr.Markdown(label="Passages")
    ask.click(fn=do_ask, inputs=[query, buildb,
                                 openai_api_key, dataset,
                                 length, marg, k, sources,
                                 docs], outputs=[answer, context, passages, docs, stats])

demo.queue(concurrency_count=20)
demo.launch(show_error=True)