import copy
import datetime
import json
import os
import re
import string
import time
from typing import Any, Optional

import gradio as gr
import openai
import google.generativeai as genai

# Set up LLM APIs
llm_api_options = ['gemini-pro', 'gemini-1.5-flash', 'gpt-3.5-turbo-1106']
def query_gpt_model(
    prompt: str,
    llm: str = 'gpt-3.5-turbo-1106',
    client: Optional[Any] = None,
    temperature: float = 0.0,
    max_decode_steps: int = 512,
    seconds_to_reset_tokens: float = 30.0,
) -> str:
    while True:
        try:
            raw_response = client.chat.completions.with_raw_response.create(
                model=llm,
                max_tokens=max_decode_steps,
                temperature=temperature,
                messages=[
                    {'role': 'user', 'content': prompt},
                ],
            )
            completion = raw_response.parse()
            return completion.choices[0].message.content
        except openai.RateLimitError as e:
            print(f'{datetime.datetime.now()}: query_gpt_model: RateLimitError {e.message}: {e}')
            time.sleep(seconds_to_reset_tokens)
        except openai.APIError as e:
            print(f'{datetime.datetime.now()}: query_gpt_model: APIError {e.message}: {e}')
            print(f'{datetime.datetime.now()}: query_gpt_model: Retrying after 5 seconds...')
            time.sleep(5)
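# A minimal usage sketch (the key string below is illustrative, not real):
#   client = openai.OpenAI(api_key='sk-...')
#   print(query_gpt_model('Say hello.', llm='gpt-3.5-turbo-1106', client=client))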
safety_settings = [
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
]
def query_gemini_model(
    prompt: str,
    llm: str = 'gemini-pro',
    client: Optional[Any] = None,
    retries: int = 10,
) -> str:
    del client  # The Gemini SDK is configured globally via genai.configure.
    model = genai.GenerativeModel(llm)
    while retries > 0:
        try:
            response = model.generate_content(prompt, safety_settings=safety_settings)
            text_response = response.text.replace("**", "")
            return text_response
        except Exception as e:
            print(f'{datetime.datetime.now()}: query_gemini_model: Error: {e}')
            print(f'{datetime.datetime.now()}: query_gemini_model: Retrying after 5 seconds...')
            retries -= 1
            time.sleep(5)
    raise RuntimeError('query_gemini_model: all retries failed')
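# A minimal usage sketch (assumes genai.configure was called with a valid key):
#   genai.configure(api_key='...')
#   print(query_gemini_model('Say hello.', llm='gemini-pro'))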
def query_model(
    prompt: str,
    model_name: str = 'gemini-pro',
    client: Optional[Any] = None,
) -> str:
    model_type = model_name.split('-')[0]
    if model_type == "gpt":
        return query_gpt_model(prompt, llm=model_name, client=client)
    elif model_type == "gemini":
        return query_gemini_model(prompt, llm=model_name, client=client)
    else:
        raise ValueError(f'Unexpected model_name: {model_name}')
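# Dispatch is keyed on the model-name prefix, e.g. (sketch):
#   query_model('Hi', model_name='gemini-1.5-flash')                   # routes to Gemini
#   query_model('Hi', model_name='gpt-3.5-turbo-1106', client=client)  # routes to OpenAI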
# Load QuALITY dataset
_ONE2ONE_FIELDS = (
    'article',
    'article_id',
    'set_unique_id',
    'writer_id',
    'source',
    'title',
    'topic',
    'url',
    'author',
)
quality_dev = []
with open('QuALITY.v1.0.1.htmlstripped.dev', 'r') as f:
    for line in f:
        j = json.loads(line)
        fields = {k: j[k] for k in _ONE2ONE_FIELDS}
        fields.update({
            'questions': [q['question'] for q in j['questions']],
            'question_ids': [q['question_unique_id'] for q in j['questions']],
            'difficults': [q['difficult'] for q in j['questions']],
            'options': [q['options'] for q in j['questions']],
        })
        fields.update({
            'gold_labels': [q['gold_label'] for q in j['questions']],
            'writer_labels': [q['writer_label'] for q in j['questions']],
        })
        quality_dev.append(fields)
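# Each quality_dev entry holds one article plus parallel per-question lists, e.g. (sketch):
#   {'article': '...', 'title': '...', 'questions': ['Q1', ...],
#    'options': [['choice A', 'choice B', 'choice C', 'choice D'], ...],
#    'gold_labels': [2, ...]}  # gold labels are 1-indexed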
# Helper functions
all_lowercase_letters = string.ascii_lowercase  # "abcd...xyz"
bracketed_lowercase_letters_set = set(
    [f"({l})" for l in all_lowercase_letters]
)  # {"(a)", ...}
bracketed_uppercase_letters_set = set(
    [f"({l.upper()})" for l in all_lowercase_letters]
)  # {"(A)", ...}
choices = ['(A)', '(B)', '(C)', '(D)']
def get_index_from_symbol(answer):
    """Get the index from the letter symbols A, B, C, D, to extract answer texts.

    Args:
      answer (str): the string of answer like "(B)".

    Returns:
      index (int): how far the given choice is from "a", like 1 for answer "(B)".
    """
    answer = str(answer).lower()
    # Extract the choice letter from within the brackets.
    if answer in bracketed_lowercase_letters_set:
        answer = re.findall(r"\(.*?\)", answer)[0][1]
    index = ord(answer) - ord("a")
    return index
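# For example: get_index_from_symbol('(B)') -> 1 and get_index_from_symbol('(a)') -> 0.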
def count_words(text):
    """Simple word counting."""
    return len(text.split())
def quality_gutenberg_parser(raw_article):
    """Parse Gutenberg articles in the QuALITY dataset."""
    lines = []
    previous_line = None
    for line in raw_article.split('\n'):
        line = line.strip()
        original_line = line
        if line == '':
            if previous_line == '':
                line = '\n'
            else:
                previous_line = original_line
                continue
        previous_line = original_line
        lines.append(line)
    return ' '.join(lines)
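# Note: a run of two or more blank lines becomes a '\n' marker, so
# quality_gutenberg_parser(raw).split('\n') yields one string per paragraph,
# with the original hard line wrapping joined by spaces.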
# ReadAgent (1) Episode Pagination
prompt_pagination_template = """
You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.
Numbered labels are in angle brackets. For example, if the label number is 19, it shows as <19> in text.
Please choose one label at which it is natural to break reading.
Such a point can be a scene transition, the end of a dialogue, the end of an argument, a narrative transition, etc.
Please answer with the break point label and explain.
For example, if <57> is a good point to break, answer with \"Break point: <57>\n Because ...\"

Passage:

{0}
{1}
{2}
"""
def parse_pause_point(text):
    """Extract the numbered break-point label, e.g. 57 from "Break point: <57>"."""
    text = text.strip()
    if text.startswith("Break point:"):
        text = text[len("Break point:"):].strip()
    if not text.startswith('<'):
        return None
    end = text.find('>')
    if end == -1:
        return None
    number = text[1:end]
    return int(number) if number.isnumeric() else None
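# For example: parse_pause_point('Break point: <57>\n Because ...') -> 57,
# while a malformed response returns None.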
def quality_pagination(example,
                       model_name='gemini-pro',
                       client=None,
                       word_limit=600,
                       start_threshold=280,
                       max_retries=10,
                       verbose=True,
                       allow_fallback_to_last=True):
    article = example['article']
    title = example['title']
    text_output = f"[Pagination][Article {title}]" + '\n\n'
    paragraphs = quality_gutenberg_parser(article).split('\n')
    i = 0
    pages = []
    while i < len(paragraphs):
        # Show the tail of the previous page as preceding context.
        preceding = "" if i == 0 else "...\n" + '\n'.join(pages[-1])
        passage = [paragraphs[i]]
        wcount = count_words(paragraphs[i])
        j = i + 1
        # Accumulate paragraphs up to word_limit, inserting numbered <j> labels
        # once the running word count reaches start_threshold.
        while wcount < word_limit and j < len(paragraphs):
            wcount += count_words(paragraphs[j])
            if wcount >= start_threshold:
                passage.append(f"<{j}>")
            passage.append(paragraphs[j])
            j += 1
        passage.append(f"<{j}>")
        end_tag = "" if j == len(paragraphs) else paragraphs[j] + "\n..."
        pause_point = None
        if wcount < 350:
            # The remaining text is short enough to close out as the final page.
            pause_point = len(paragraphs)
        else:
            prompt = prompt_pagination_template.format(preceding, '\n'.join(passage), end_tag)
            response = query_model(prompt=prompt, model_name=model_name, client=client).strip()
            pause_point = parse_pause_point(response)
            if pause_point and (pause_point <= i or pause_point > j):
                # Reject out-of-range break points.
                pause_point = None
        if pause_point is None:
            if allow_fallback_to_last:
                pause_point = j
            else:
                raise ValueError(f"prompt:\n{prompt},\nresponse:\n{response}\n")
        page = paragraphs[i:pause_point]
        pages.append(page)
        text_output += f"Paragraph {i}-{pause_point-1}: {page}\n\n"
        i = pause_point
    text_output += f"\n\n[Pagination] Done with {len(pages)} pages"
    return pages, text_output
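# Sketch of the output: `pages` is a list of pages, each a list of consecutive
# paragraph strings, e.g. pages[0] == paragraphs[0:first_pause_point].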
# ReadAgent (2) Memory Gisting
prompt_shorten_template = """
Please shorten the following passage.
Just give me a shortened version. DO NOT explain your reason.

Passage:
{}
"""
def quality_gisting(example, pages, model_name, client=None, word_limit=600, start_threshold=280, verbose=True):
    article = example['article']
    title = example['title']
    word_count = count_words(article)
    text_output = f"[Gisting][Article {title}], {word_count} words\n\n"
    shortened_pages = []
    for i, page in enumerate(pages):
        prompt = prompt_shorten_template.format('\n'.join(page))
        response = query_model(prompt, model_name, client)
        shortened_text = response.strip()
        shortened_pages.append(shortened_text)
        text_output += "[gist] page {}: {}\n\n".format(i, shortened_text)
    shortened_article = '\n'.join(shortened_pages)
    gist_word_count = count_words(shortened_article)
    text_output += '\n\n' + f"Shortened article:\n{shortened_article}\n\n"
    output = copy.deepcopy(example)
    output.update({
        'title': title,
        'word_count': word_count,
        'gist_word_count': gist_word_count,
        'shortened_pages': shortened_pages,
        'pages': pages,
    })
    text_output += f"\n\ncompression rate {round(100.0 - gist_word_count/word_count*100, 2)}% ({gist_word_count}/{word_count})"
    return output, text_output
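# Worked example: gisting a 5,000-word article down to 1,000 words reports
# "compression rate 80.0% (1000/5000)".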
# ReadAgent (3) Look-Up
prompt_lookup_template = """
The following text is what you remember from reading an article, together with a multiple choice question related to it.
You may read 1 to 6 page(s) of the article again to refresh your memory to prepare yourself for the question.
Please respond with which page(s) you would like to read.
For example, if you only need to read Page 8, respond with \"I want to look up Page [8] to ...\";
if you would like to read Page 7 and 12, respond with \"I want to look up Page [7, 12] to ...\";
if you would like to read Page 2, 3, 7, 15 and 18, respond with \"I want to look up Page [2, 3, 7, 15, 18] to ...\";
if you would like to read Page 3, 4, 5, 12, 13 and 16, respond with \"I want to look up Page [3, 4, 5, 12, 13, 16] to ...\".
DO NOT select more pages if you don't need to.
DO NOT answer the question yet.

Text:
{}

Question:
{}
{}

Take a deep breath and tell me: Which page(s) would you like to read again?
"""
prompt_answer_template = """
Read the following article and answer a multiple choice question.
For example, if (C) is correct, answer with \"Answer: (C) ...\"

Article:
{}

Question:
{}
{}
"""
def quality_parallel_lookup(example, model_name, client, verbose=True):
    title = example['title']
    word_count = example['word_count']
    gist_word_count = example['gist_word_count']
    pages = example['pages']
    shortened_pages = example['shortened_pages']
    questions = example['questions']
    options = example['options']
    gold_labels = example['gold_labels']  # numerical, 1-indexed: [1, 2, 3, 4]
    text_outputs = [f"[Look-Up][Article {title}] {word_count} words"]
    # Prefix each gist with its page number so the model can request pages by index.
    shortened_pages_pidx = []
    for i, shortened_text in enumerate(shortened_pages):
        shortened_pages_pidx.append("\nPage {}:\n".format(i) + shortened_text)
    shortened_article = '\n'.join(shortened_pages_pidx)
    for i, label in enumerate(gold_labels):
        # Only run a single question (index 1) for the demo.
        if i != 1:
            continue
        q = questions[i]
        text_output = f"question {i}: {q}" + '\n\n'
        options_i = [f"{ol} {o}" for ol, o in zip(choices, options[i])]
        text_output += "options: " + "\n".join(options_i)
        text_output += '\n\n'
        prompt_lookup = prompt_lookup_template.format(shortened_article, q, '\n'.join(options_i))
        page_ids = []
        response = query_model(prompt=prompt_lookup, model_name=model_name, client=client).strip()
        # Parse the bracketed page list, e.g. "[2, 7]", from the response.
        try:
            start = response.index('[')
        except ValueError:
            start = len(response)
        try:
            end = response.index(']')
        except ValueError:
            end = 0
        if start < end:
            page_ids_str = response[start + 1:end].split(',')
            page_ids = []
            for p in page_ids_str:
                if p.strip().isnumeric():
                    page_id = int(p)
                    if page_id < 0 or page_id >= len(pages):
                        text_output += f"Skip invalid page number: {page_id}\n\n"
                    else:
                        page_ids.append(page_id)
        text_output += "Model chose to look up page {}\n\n".format(page_ids)
        # Memory expansion after look-up: replace each selected gist with its original page.
        expanded_shortened_pages = shortened_pages[:]
        if len(page_ids) > 0:
            for page_id in page_ids:
                expanded_shortened_pages[page_id] = '\n'.join(pages[page_id])
        expanded_shortened_article = '\n'.join(expanded_shortened_pages)
        expanded_gist_word_count = count_words(expanded_shortened_article)
        text_output += "Expanded shortened article:\n" + expanded_shortened_article + '\n\n'
        prompt_answer = prompt_answer_template.format(expanded_shortened_article, q, '\n'.join(options_i))
        model_choice = None
        response = query_model(prompt=prompt_answer, model_name=model_name, client=client)
        response = response.strip()
        for j, choice in enumerate(choices):
            if response.startswith(f"Answer: {choice}") or response.startswith(f"Answer: {choice[1]}"):
                model_choice = j + 1  # 1-indexed, to match gold_labels
                break
        is_correct = 1 if model_choice == label else 0
        prediction = choices[model_choice - 1] if model_choice is not None else "N/A"
        text_output += f"reference answer: {choices[label - 1]}, model prediction: {prediction}, is_correct: {is_correct}" + '\n\n'
        text_output += f"compression rate {round(100.0 - gist_word_count/word_count*100, 2)}% ({gist_word_count}/{word_count})" + '\n\n'
        text_output += f"compression rate after look-up {round(100.0 - expanded_gist_word_count/word_count*100, 2)}% ({expanded_gist_word_count}/{word_count})" + '\n\n'
        text_output += '\n\n'
        text_outputs.append(text_output)
    return text_outputs
# ReadAgent
def query_model_with_quality(
    index: int,
    model_name: str = 'gemini-pro',
    api_key: Optional[str] = None,
):
    # Set up the API key first.
    client = None
    model_type = model_name.split('-')[0]
    if model_type == "gpt":
        # api_key = os.environ.get('OPEN_AI_KEY')
        client = openai.OpenAI(api_key=api_key)
    elif model_type == "gemini":
        # api_key = os.environ.get('GEMINI_API_KEY')
        genai.configure(api_key=api_key)
    example = quality_dev[index]
    article = f"[Title: {example['title']}]\n\n{example['article']}"
    pages, pagination = quality_pagination(example, model_name, client)
    print('Finished Pagination.')
    example_with_gists, gisting = quality_gisting(example, pages, model_name, client)
    print('Finished Gisting.')
    answers = quality_parallel_lookup(example_with_gists, model_name, client)
    return article, pagination, gisting, '\n\n'.join(answers)
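# End-to-end usage sketch (assumes a valid key for the chosen backend):
#   article, pagination, gisting, answers = query_model_with_quality(
#       13, model_name='gemini-pro', api_key='...')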
with gr.Blocks() as demo:
    gr.Markdown(
        """
# A Human-Inspired Reading Agent with Gist Memory of Very Long Contexts

[[website]](https://read-agent.github.io/)
[[view on huggingface]](https://huggingface.co/spaces/ReadAgent/read-agent)
[[arXiv]](https://arxiv.org/abs/2402.09727)
[[OpenReview]](https://openreview.net/forum?id=OTmcsyEO5G)

![teaser](/file=./asset/teaser.png)

The demo below showcases a version of the ReadAgent algorithm, which is inspired by how humans interactively read long documents.
We implement ReadAgent as a simple prompting system that uses the advanced language capabilities of LLMs to (1) decide what content to store together in a memory episode (**Episode Pagination**), (2) compress those memory episodes into short episodic memories called gist memories (**Memory Gisting**), and (3) take actions to look up passages in the original text if ReadAgent needs to remind itself of relevant details to complete a task (**Parallel Lookup and QA**).
This demo handles long-document reading comprehension tasks ([QuALITY](https://arxiv.org/abs/2112.08608); max 6,000 words) efficiently.
To get started, choose an index into the QuALITY dataset.
This demo uses the Gemini API or the OpenAI API, so it requires the corresponding API key.
""")
    with gr.Row():
        with gr.Column():
            llm_options = gr.Radio(llm_api_options, label="Backend LLM API", value='gemini-pro')
            llm_api_key = gr.Textbox(
                label="Paste your OpenAI API key (sk-...) or Gemini API key",
                lines=1,
                type="password",
            )
            index = gr.Dropdown(list(range(len(quality_dev))), value=13, label="QuALITY Index")
            button = gr.Button("Execute")
    original_article = gr.Textbox(label="Original Article", lines=20)
    pagination_results = gr.Textbox(label="(1) Episode Pagination", lines=20)
    gisting_results = gr.Textbox(label="(2) Memory Gisting", lines=20)
    lookup_qa_results = gr.Textbox(label="(3) Parallel Lookup and QA", lines=20)
    button.click(
        fn=query_model_with_quality,
        inputs=[
            index,
            llm_options,
            llm_api_key,
        ],
        outputs=[
            original_article,
            pagination_results,
            gisting_results,
            lookup_qa_results,
        ],
    )

if __name__ == '__main__':
    demo.launch(allowed_paths=['./asset/teaser.png'])