from dotenv import load_dotenv
import os
load_dotenv()

import concurrent.futures
from collections import defaultdict
import pandas as pd
import numpy as np
import json
import pickle
import pprint
from io import StringIO
import textwrap
import time
import re

from openai import OpenAI
openai_client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

import octoai
octoai_client = octoai.client.Client(token=os.getenv('OCTOML_KEY'))

from pinecone import Pinecone, ServerlessSpec
pc = Pinecone(api_key=os.getenv('PINECONE_API_KEY'))
pc_256 = pc.Index('prorata-postman-ds-256-v2')
pc_128 = pc.Index('prorata-postman-ds-128-v2')

from langchain.text_splitter import RecursiveCharacterTextSplitter
sentence_splitter = RecursiveCharacterTextSplitter(
    chunk_size=128,
    chunk_overlap=0,
    separators=["\n\n", "\n", "."],
    keep_separator=False
)
from functools import cache

def get_embedding(text, model="text-embedding-3-small"):
    text = text.replace("\n", " ")
    return openai_client.embeddings.create(input=[text], model=model).data[0].embedding

def get_embedding_l(text_l, model="text-embedding-3-small"):
    text_l = [text.replace("\n", " ") for text in text_l]
    res = openai_client.embeddings.create(input=text_l, model=model)
    embeds = [record.embedding for record in res.data]
    return embeds
def parse_json_string(content):
    # Repeatedly patch the two most common JSON errors reported by json.loads,
    # splicing the missing character in at the offset given in the error message.
    result = None
    fixed_content = content
    for _ in range(20):
        try:
            result = json.loads(fixed_content)
            break
        except Exception as e:
            print(e)
            if "Expecting ',' delimiter" in str(e):
                # "Expecting ',' delimiter: line x column y (char d)"
                idx = int(re.findall(r'\(char (\d+)\)', str(e))[0])
                fixed_content = fixed_content[:idx] + ',' + fixed_content[idx:]
                print(fixed_content)
                print()
            elif "Expecting property name enclosed in double quotes" in str(e):
                # "Expecting property name enclosed in double quotes: line x column y (char d)"
                idx = int(re.findall(r'\(char (\d+)\)', str(e))[0])
                fixed_content = fixed_content[:idx-1] + '}' + fixed_content[idx:]
                print(fixed_content)
                print()
            else:
                raise ValueError(str(e))
    return result
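# Illustrative sketch (hypothetical input; not called anywhere): how the repair loop
# above behaves on a malformed list. json.loads reports a missing ',' with a character
# offset, and the loop splices the comma in at that offset and retries.
def _example_parse_json_string():
    broken = '[{"fact": "A is B"} {"fact": "C is D"}]'  # missing ',' between objects
    return parse_json_string(broken)  # -> [{'fact': 'A is B'}, {'fact': 'C is D'}]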
# prompt_af_template_llama3 = "Please break down the following paragraph into independent and atomic facts. Format your response as a single JSON object, a list of facts:\n\n{}"
prompt_af_template_llama3 = "Please break down the following paragraph into independent and atomic facts. Format your response in JSON as a list of 'fact' objects:\n\n{}"

# prompt_tf_template = "Given the context below, answer the question that follows. Please format your answer in JSON with a yes or no determination and rationale for the determination. \n\nContext: {}\n\nQuestion: {} Is this claim true or false?"
# prompt_tf_template = "Given the context below, answer the question that follows. Please format your answer in JSON with a yes or no determination and rationale for the determination. \n\nContext: {}\n\nQuestion: <{}> Is the previous claim (in between <> braces) true or false?"
prompt_tf_template = "Given the context below, answer the question that follows. Please format your answer in JSON with a yes or no determination and rationale for the determination. \n\nContext: {}\n\nQuestion: <{}> Does the context explicitly support the previous claim (in between <> braces), true or false?"
def get_atoms_list(answer, file=None):
    prompt_af = prompt_af_template_llama3.format(answer)
    response, atoms_l = None, []
    content, idx1, idx2 = '', 0, 0  # so the except handler below can print safely
    for _ in range(5):
        try:
            response = octoai_client.chat.completions.create(
                model="meta-llama-3-70b-instruct",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt_af}
                ],
                # response_format={"type": "json_object"},
                max_tokens=512,
                presence_penalty=0,
                temperature=0.1,
                top_p=0.9,
            )
            content = response.choices[0].message.content
            # The model wraps its JSON in ``` fences; pull out the fenced block.
            idx1 = content.find('```')
            idx2 = idx1+3 + content[idx1+3:].find('```')
            # atoms_l = json.loads(content[idx1+3:idx2])
            atoms_l = parse_json_string(content[idx1+3:idx2])
            atoms_l = [a['fact'] for a in atoms_l]
            break
        except Exception as error:
            print(error, file=file)
            print(response, file=file)
            print(content[idx1+3:idx2], file=file)
            time.sleep(2)
    return atoms_l
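# Illustrative sketch (hypothetical answer text; not called anywhere): get_atoms_list
# sends the atomic-fact prompt to Llama 3 on OctoAI and returns the extracted fact
# strings. The exact output depends on the model.
def _example_get_atoms_list():
    answer = "The Eiffel Tower is in Paris. It was completed in 1889."
    # Expected shape, e.g.:
    # ["The Eiffel Tower is in Paris.", "The Eiffel Tower was completed in 1889."]
    return get_atoms_list(answer)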
def get_topk_matches(atom, k=5, pc_index=pc_256):
    embed_atom = get_embedding(atom)
    res = pc_index.query(vector=embed_atom, top_k=k, include_metadata=True)
    return res['matches']
def get_match_atom_entailment_determination(_match, atom, file=None):
    prompt_tf = prompt_tf_template.format(_match['metadata']['text'], atom)
    response = None
    chunk_determination = {}
    chunk_determination['chunk_id'] = _match['id']
    chunk_determination['true'] = False
    for _ in range(5):
        try:
            response = octoai_client.chat.completions.create(
                model="meta-llama-3-70b-instruct",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt_tf}
                ],
                # response_format={"type": "json_object"},
                max_tokens=512,
                # presence_penalty=0,
                temperature=0.1,
                # top_p=0.9,
            )
            content = response.choices[0].message.content
            idx1 = content.find('{')
            idx2 = content.find('}')
            chunk_determination.update(json.loads(content[idx1:idx2+1]))
            _det_lower = chunk_determination['determination'].lower()
            chunk_determination['true'] = "true" in _det_lower or "yes" in _det_lower
            break
        except Exception as error:
            print(error, file=file)
            print(prompt_tf, file=file)
            print(response, file=file)
            time.sleep(2)
    return chunk_determination
def get_atom_support(atom, file=None):
    topk_matches = get_topk_matches(atom)
    atom_support = {}
    for _match in topk_matches:
        chunk_determination = atom_support.get(_match['metadata']['url'], {})
        if not chunk_determination or not chunk_determination['true']:
            atom_support[_match['metadata']['url']] = get_match_atom_entailment_determination(_match, atom, file=file)
    return atom_support

def get_atom_support_list(atoms_l, file=None):
    return [get_atom_support(a, file=file) for a in atoms_l]
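# Illustrative sketch (hypothetical atom and URL; not called anywhere): for one atom,
# get_atom_support keeps a single determination per source URL among the top-k matches,
# only re-asking the LLM while a URL has not yet yielded a "true" determination.
# Example return shape:
#   {"https://example.com/article": {"chunk_id": "...", "true": True,
#                                    "determination": "yes", "rationale": "..."}}
def _example_get_atom_support():
    return get_atom_support("The Eiffel Tower was completed in 1889.")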
def credit_atom_support_list(atom_support_l):
    num_atoms = len(atom_support_l)
    credit_d = defaultdict(float)
    for atom_support in atom_support_l:
        atom_support_size = 0.0
        for url_determination_d in atom_support.values():
            if url_determination_d['true']:
                atom_support_size += 1.0
        for url, url_determination_d in atom_support.items():
            if url_determination_d['true']:
                credit_d[url] += 1.0 / atom_support_size
    for url in credit_d.keys():
        credit_d[url] = credit_d[url] / num_atoms
    return credit_d
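# Worked example (hypothetical URLs; not called anywhere): each atom splits one unit of
# credit evenly across the URLs that support it, and the totals are normalized by the
# number of atoms.
def _example_credit_atom_support_list():
    atom_support_l = [
        # atom 0: supported by both URLs -> each gets 1/2
        {"https://a.example": {"true": True}, "https://b.example": {"true": True}},
        # atom 1: supported only by a.example -> it gets 1
        {"https://a.example": {"true": True}, "https://b.example": {"true": False}},
    ]
    # credit: a.example = (0.5 + 1.0) / 2 = 0.75, b.example = 0.5 / 2 = 0.25
    return credit_atom_support_list(atom_support_l)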
def print_atom_support(atom_support, prefix='', file=None):
    for url, chunk_determination in atom_support.items():
        print(f"{prefix}{url}:", file=file)
        print(f"{prefix} Determination: {'YES' if chunk_determination['true'] else 'NO'}", file=file)
        print(f"{prefix} Rationale: {chunk_determination['rationale']}", file=file)

def print_credit_dist(credit_dist, prefix='', url_to_id=None, file=None):
    credit_l = [(url, w) for url, w in credit_dist.items()]
    credit_l = sorted(credit_l, key=lambda x: x[1], reverse=True)
    for url, w in credit_l:
        if url_to_id is None:
            print(f"{prefix}{url}: {100*w:.2f}%", file=file)
        else:
            print(f"{prefix}{url_to_id[url]} {url}: {100*w:.2f}%", file=file)
# concurrent LLM calls
def get_atom_topk_matches_l_concurrent(atoms_l, max_workers=4):
    atom_topkmatches_l = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for atom in atoms_l:
            futures.append(executor.submit(get_topk_matches, atom))
        for f in futures:
            r = f.result()
            atom_topkmatches_l.append(r)
    return atom_topkmatches_l
def aggregate_atom_topkmatches_l(atom_topkmatches_l):
    atom_url_to_aggmatch_maps_l = []
    for atom_topkmatches in atom_topkmatches_l:
        atom_url_to_aggmatch_map = {}
        atom_url_to_aggmatch_maps_l.append(atom_url_to_aggmatch_map)
        for _match in atom_topkmatches:
            if _match['metadata']['url'] not in atom_url_to_aggmatch_map:
                match_copy = {}
                match_copy['id'] = _match['id']
                match_copy['id_l'] = [_match['id']]
                match_copy['offset_l'] = [0]
                match_copy['score'] = _match['score']
                match_copy['values'] = _match['values']
                # TODO: change to list of chunks and then append at query time
                match_copy['metadata'] = {}
                match_copy['metadata']['url'] = _match['metadata']['url']
                match_copy['metadata']['chunk'] = _match['metadata']['chunk']
                match_copy['metadata']['text'] = _match['metadata']['text']
                match_copy['metadata']['title'] = _match['metadata']['title']
                atom_url_to_aggmatch_map[_match['metadata']['url']] = match_copy
            else:
                # Same URL seen again for this atom: merge the chunk into the aggregate match.
                prev_match = atom_url_to_aggmatch_map[_match['metadata']['url']]
                prev_match['id_l'].append(_match['id'])
                prev_match['offset_l'].append(len(prev_match['metadata']['text']))
                prev_match['metadata']['text'] += f"\n\n{_match['metadata']['text']}"
    atomidx_w_single_url_aggmatch_l = []
    for idx, atom_url_to_aggmatch_map in enumerate(atom_url_to_aggmatch_maps_l):
        for agg_match in atom_url_to_aggmatch_map.values():
            atomidx_w_single_url_aggmatch_l.append((idx, agg_match))
    return atomidx_w_single_url_aggmatch_l
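# Illustrative sketch (hypothetical match dicts; not called anywhere): matches for the
# same atom that share a URL are merged into one aggregate match; their texts are joined
# with blank lines and offset_l records the accumulated text length at each append.
def _example_aggregate_atom_topkmatches_l():
    def match(_id, text):
        # Minimal stand-in for a Pinecone match dict (hypothetical values).
        return {"id": _id, "score": 0.9, "values": [],
                "metadata": {"url": "https://a.example", "chunk": _id,
                             "text": text, "title": "A"}}
    agg = aggregate_atom_topkmatches_l([[match(0, "First chunk."), match(1, "Second chunk.")]])
    # -> [(0, aggregate)] where id_l == [0, 1], offset_l == [0, 12], and
    #    metadata['text'] == "First chunk.\n\nSecond chunk."
    return agg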
def get_atmom_support_l_from_atomidx_w_single_url_aggmatch_l_concurrent(atoms_l, atomidx_w_single_url_aggmatch_l, max_workers=4):
    atom_support_l = [{} for _ in atoms_l]
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for atomidx_w_single_url_aggmatch in atomidx_w_single_url_aggmatch_l:
            futures.append(executor.submit(
                get_match_atom_entailment_determination,
                atomidx_w_single_url_aggmatch[1],
                atoms_l[atomidx_w_single_url_aggmatch[0]],
            ))
        for f, atomidx_w_single_url_aggmatch in zip(futures, atomidx_w_single_url_aggmatch_l):
            aggmatch_determination = f.result()
            atom_support = atom_support_l[atomidx_w_single_url_aggmatch[0]]
            atom_support[atomidx_w_single_url_aggmatch[1]['metadata']['url']] = aggmatch_determination
    return atom_support_l
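# End-to-end sketch of the concurrent path (hypothetical answer text; not called
# anywhere): split an answer into atomic facts, retrieve top-k chunks per atom in
# parallel, aggregate matches per (atom, URL), run the entailment checks in parallel,
# then distribute credit across source URLs.
def _example_concurrent_pipeline(answer="The Eiffel Tower is in Paris."):
    atoms_l = get_atoms_list(answer)
    atom_topkmatches_l = get_atom_topk_matches_l_concurrent(atoms_l)
    atomidx_w_single_url_aggmatch_l = aggregate_atom_topkmatches_l(atom_topkmatches_l)
    atom_support_l = get_atmom_support_l_from_atomidx_w_single_url_aggmatch_l_concurrent(
        atoms_l, atomidx_w_single_url_aggmatch_l)
    return credit_atom_support_list(atom_support_l)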
style_str = """
<style>
  .doc-title {
    /* font-family: cursive, sans-serif; */
    font-family: Optima, sans-serif;
    width: 100%;
    display: inline-block;
    font-size: 2em;
    font-weight: bolder;
    padding-top: 20px;
    /* font-style: italic; */
  }
  .doc-url {
    /* font-family: cursive, sans-serif; */
    font-size: 1em;
    padding-left: 40px;
    padding-bottom: 10px;
    /* font-weight: bolder; */
    /* font-style: italic; */
  }
  .doc-text {
    /* font-family: cursive, sans-serif; */
    font-family: Optima, sans-serif;
    font-size: 1.5em;
    white-space: pre-wrap;
    padding-left: 40px;
    padding-bottom: 20px;
    /* font-weight: bolder; */
    /* font-style: italic; */
  }
  .doc-text .chunk-separator {
    /* font-style: italic; */
    color: #0000FF;
  }
  .doc-title > img {
    width: 22px;
    height: 22px;
    border-radius: 50%;
    overflow: hidden;
    background-color: transparent;
    display: inline-block;
    vertical-align: middle;
  }
  .doc-title > score {
    font-family: Optima, sans-serif;
    font-weight: normal;
    float: right;
  }
</style>
"""