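"""Hugging Face Space: a retrieval-augmented generation (RAG) chatbot.

The app embeds a user query, retrieves the most similar paper abstracts from a
FAISS index of Design Research Collective publications, prepends them to the
prompt, streams an answer from an instruction-tuned LLM, and appends the
matching references so every answer is cited.
"""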
import json # to inspect serialized bibliography entries
import threading # to run LLM generation in a background thread
import time # to pace token streaming for a smoother UI
import datasets # to load the publications dataset
import faiss # to build a vector index for similarity search
import gradio # for the chat interface
import numpy # to work with embedding vectors
import pandas # for dataframe manipulation
import sentence_transformers # to load the embedding model
import spaces # for ZeroGPU allocation on Hugging Face Spaces
import transformers # to load and run the LLM
# Constants
GREETING = (
    "Howdy! I'm an AI agent that uses a [retrieval-augmented generation]("
    "https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about research by the "
    "[Design Research Collective](https://cmudrc.github.io/). And the best part is that I always cite my sources! What"
    " can I tell you about today?"
)
EXAMPLE_QUERIES = [
    "Tell me about new research at the intersection of additive manufacturing and machine learning",
    "What is a physics-informed neural network and what can it be used for?",
    "What can agent-based models do about climate change?",
]
EMBEDDING_MODEL_NAME = "allenai-specter"
LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
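# "allenai-specter" is SPECTER, a sentence-transformers model trained on
# scientific papers, so it suits a corpus of publication abstracts; the Qwen2
# 7B instruct model generates the final answers.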
# Load the dataset and convert to pandas
data = datasets.load_dataset("ccm/publications")["train"].to_pandas()

# Filter out any publications without an abstract
abstract_is_null = [
    '"abstract": null' in json.dumps(bibdict) for bibdict in data["bib_dict"].values
]
data = data[~pandas.Series(abstract_is_null)]
data.reset_index(inplace=True)
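# The dataset is assumed to provide at least three columns used below:
#   - "bib_dict": a dict of bibliographic fields ("abstract", "author",
#     "pub_year", "title")
#   - "embedding": a precomputed vector per paper, assumed to come from the
#     same embedding model loaded below
#   - "author_pub_id": the Google Scholar citation identifier for the paper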
# Create a FAISS index for fast similarity search. The vectors are
# L2-normalized first, so inner-product search is equivalent to cosine
# similarity; a flat inner-product index needs no training step.
vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype(numpy.float32)
faiss.normalize_L2(vectors)
index = faiss.IndexFlatIP(vectors.shape[1])
index.add(vectors)
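# Optional sanity check (illustrative): once everything is normalized, each
# vector's nearest neighbor should be itself with similarity ~1.0.
#   D, I = index.search(vectors[:1], 1)
#   assert I[0][0] == 0 and abs(D[0][0] - 1.0) < 1e-5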
# Load the embedding model for encoding queries later
model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
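# The encoder's output dimension must match the index; a quick check
# (illustrative, not in the original script) would be:
#   assert model.get_sentence_embedding_dimension() == index.d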
def search(query: str, k: int) -> tuple[str, str]:
    """
    Searches the dataset for the top k most relevant papers to the query

    Args:
        query (str): The user's query
        k (int): The number of results to return

    Returns:
        tuple[str, str]: A tuple containing the search results and references
    """
    query = numpy.expand_dims(model.encode(query), axis=0)
    faiss.normalize_L2(query)
    D, I = index.search(query, k)
    top_results = data.iloc[I[0]]
    search_results = (
        "You are an AI assistant who delights in helping people learn about research from the Design "
        "Research Collective. Here are several abstracts from really cool, and really relevant, "
        "papers:\n\n"
    )
    references = "\n\n## References\n\n"
    for i in range(k):
        bib = top_results["bib_dict"].values[i]
        search_results += bib["abstract"] + "\n"
        # Cite as: 1. Lastname, Lastname. (Year). [Title](Google Scholar link).
        last_names = ", ".join(
            author.split(" ")[-1] for author in bib["author"].split(" and ")
        )
        references += (
            f"{i + 1}. {last_names}. ({int(bib['pub_year'])}). [{bib['title']}]"
            "(https://scholar.google.com/citations?view_op=view_citation&citation_for_view="
            f"{top_results['author_pub_id'].values[i]}).\n"
        )
    search_results += (
        "\nUsing the information provided above, respond to this query: "
    )
    return search_results, references
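# Example usage (illustrative):
#   prompt_prefix, refs = search("agent-based models for climate change", k=3)
# The first string is prepended to the user's message as context; the second
# is appended to the LLM's answer as a formatted reference list.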
# Create an LLM pipeline that we can send queries to
tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
streamer = transformers.TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True
)
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
)
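# TextIteratorStreamer turns generation into a Python iterator: generate()
# runs in a background thread (see reply() below) and pushes decoded tokens
# into the streamer, which the main thread consumes one token at a time.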
def preprocess(message: str) -> tuple[str, str]:
    """
    Applies a preprocessing step to the user's message before the LLM receives it

    Args:
        message (str): The user's message

    Returns:
        tuple[str, str]: A tuple containing the preprocessed message and a bypass variable
    """
    block_search_results, formatted_search_results = search(message, 5)
    return block_search_results + message, formatted_search_results
def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """
    Applies a postprocessing step to the LLM's response before the user receives it

    Args:
        response (str): The LLM's response
        bypass_from_preprocessing (str): The bypass variable from the preprocessing step

    Returns:
        str: The postprocessed response
    """
    return response + bypass_from_preprocessing
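# Together, preprocess() and postprocess() implement the RAG pattern: the
# retrieved abstracts are prepended to the prompt the LLM sees, while the
# matching references are carried around the LLM (the "bypass") so they can be
# appended to its answer verbatim rather than regenerated.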
@spaces.GPU # request GPU time for this call when running on ZeroGPU hardware
def reply(message: str, history: list[list[str | None]]) -> str:
    """
    This function is responsible for crafting a response

    Args:
        message (str): The user's message
        history (list[list[str | None]]): The conversation history as (user, assistant) pairs

    Returns:
        str: The AI's response, yielded incrementally as it streams
    """
    # Apply preprocessing
    message, bypass = preprocess(message)

    # Convert the history of (user, assistant) pairs into the role/content
    # message format expected by the chat template
    history_transformer_format = [
        {"role": role, "content": message_pair[idx]}
        for message_pair in history
        for idx, role in enumerate(["user", "assistant"])
        if message_pair[idx] is not None
    ] + [{"role": "user", "content": message}]

    # Stream a response from the model
    text = tokenizer.apply_chat_template(
        history_transformer_format, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
    t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in streamer:
        if new_token != "<": # guard against stray special-token fragments
            partial_message += new_token
            time.sleep(0.05) # pace the stream so the UI updates smoothly
            yield partial_message
    yield partial_message + bypass
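# Because reply() is a generator, gradio.ChatInterface streams each yielded
# string to the chat window as it arrives instead of waiting for the full
# response.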
# Create and run the gradio interface
gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    chatbot=gradio.Chatbot(
        show_label=False, show_copy_button=True, value=[[None, GREETING]]
    ),
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
    cache_examples=True,
    fill_height=True,
).launch(debug=True)
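# launch(debug=True) blocks and serves the app; this file is assumed to be the
# Space's app.py entry point, so it can also be run locally with `python app.py`.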