# pylint: disable=W0707
# pylint: disable=W0719
import os
import json
import tiktoken
import openai
from openai import OpenAI
import requests
from constants.cli import OPENAI_MODELS
from constants.ai import SYSTEM_PROMPT, PROMPT, API_URL


def retrieve(query, k=10, filters=None):
"""Retrieves and returns dict.
Args:
query (str): User query to pass in
openai_api_key (str): openai api key. If not passed in, uses environment variable
k (int, optional): number of results passed back. Defaults to 10.
filters (dict, optional): Filters to apply to the query. You can filter based off
any piece of metadata by passing in a dict of the format {metadata_name: filter_value}
ie {"library_id": "1234"}.
See the README for more details:
https://github.com/fleet-ai/context/tree/main#using-fleet-contexts-rich-metadata
Returns:
list: List of queried results
"""
url = f"{API_URL}/query"
params = {
"query": query,
"dataset": "python_libraries",
"n_results": k,
"filters": filters,
}
return requests.post(url, json=params, timeout=120).json()
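

# Illustrative usage of retrieve() (kept as a comment so importing this module
# stays side-effect free). The filter value "1234" is a placeholder taken from
# the {"library_id": "1234"} example in the docstring above:
#
#   results = retrieve("How do I make a POST request?", k=5,
#                      filters={"library_id": "1234"})
#   for result in results:
#       print(result["metadata"]["url"])
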
def retrieve_context(query, openai_api_key, k=10, filters=None):
"""Gets the context from our libraries vector db for a given query.
Args:
query (str): User input query
k (int, optional): number of retrieved results. Defaults to 10.
"""
# First, we query the API
responses = retrieve(query, k=k, filters=filters)
# Then, we build the prompt_with_context string
prompt_with_context = ""
for response in responses:
prompt_with_context += f"\n\n### Context {response['metadata']['url']} ###\n{response['metadata']['text']}"
return {"role": "user", "content": prompt_with_context}
def construct_prompt(
messages,
context_message,
model="gpt-4-1106-preview",
cite_sources=True,
context_window=3000,
):
"""
Constructs a RAG (Retrieval-Augmented Generation) prompt by balancing the token count of messages and context_message.
If the total token count exceeds the maximum limit, it adjusts the token count of each to maintain a 1:1 proportion.
It then combines both lists and returns the result.
Parameters:
messages (List[dict]): List of messages to be included in the prompt.
context_message (dict): Context message to be included in the prompt.
model (str): The model to be used for encoding, default is "gpt-4-1106-preview".
Returns:
List[dict]: The constructed RAG prompt.
"""
# Get the encoding; default to cl100k_base
if model in OPENAI_MODELS:
encoding = tiktoken.encoding_for_model(model)
else:
encoding = tiktoken.get_encoding("cl100k_base")
# 1) calculate tokens
reserved_space = 1000
max_messages_count = int((context_window - reserved_space) / 2)
max_context_count = int((context_window - reserved_space) / 2)
# 2) construct prompt
prompts = messages.copy()
prompts.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
if cite_sources:
prompts.insert(-1, {"role": "user", "content": PROMPT})
# 3) find how many tokens each list has
messages_token_count = len(
encoding.encode(
"\n".join(
[
f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>"
for message in prompts
]
)
)
)
context_token_count = len(
encoding.encode(
f"<|im_start|>{context_message['role']}\n{context_message['content']}<|im_end|>"
)
)
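    # Note: the "<|im_start|>{role}\n{content}<|im_end|>" wrapping above is used
    # only to estimate token usage in a ChatML-like serialization; the exact
    # per-message overhead varies by model, so these counts are approximate.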
# 4) Balance the token count for each
if (messages_token_count + context_token_count) > (context_window - reserved_space):
# context has more than limit, messages has less than limit
if (messages_token_count < max_messages_count) and (
context_token_count > max_context_count
):
max_context_count += max_messages_count - messages_token_count
# messages has more than limit, context has less than limit
elif (messages_token_count > max_messages_count) and (
context_token_count < max_context_count
):
max_messages_count += max_context_count - context_token_count
# 5) Cut each list to the max count
# Cut down messages
while messages_token_count > max_messages_count:
removed_encoding = encoding.encode(
f"<|im_start|>{prompts[1]['role']}\n{prompts[1]['content']}<|im_end|>"
)
messages_token_count -= len(removed_encoding)
if messages_token_count < max_messages_count:
prompts = (
[prompts[0]]
+ [
{
"role": prompts[1]["role"],
"content": encoding.decode(
removed_encoding[
: min(
int(max_messages_count -
messages_token_count),
len(removed_encoding),
)
]
)
.replace("<|im_start|>", "")
.replace("<|im_end|>", ""),
}
]
+ prompts[2:]
)
else:
prompts = [prompts[0]] + prompts[2:]
# Cut down context
if context_token_count > max_context_count:
# Taking a proportion of the content chars length
reduced_chars_length = int(
len(context_message["content"]) *
(max_context_count / context_token_count)
)
context_message["content"] = context_message["content"][:reduced_chars_length]
# 6) Combine both lists
prompts.insert(-1, context_message)
return prompts
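

# Illustrative usage of construct_prompt() (comment only): combine a chat history
# with a retrieved context message before handing it to a chat function below.
#
#   history = [{"role": "user", "content": "How do I read a CSV with pandas?"}]
#   context_message = retrieve_context(
#       "How do I read a CSV with pandas?", os.environ.get("OPENAI_API_KEY"))
#   rag_prompt = construct_prompt(history, context_message, context_window=3000)
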
def get_remote_chat_response(messages, model="gpt-4-1106-preview"):
"""
Returns a streamed OpenAI chat response.
Parameters:
messages (List[dict]): List of messages to be included in the prompt.
model (str): The model to be used for encoding, default is "gpt-4-1106-preview".
Returns:
str: The streamed OpenAI chat response.
"""
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
try:
response = client.chat.completions.create(
model=model, messages=messages, temperature=0.2, stream=True
)
        for chunk in response:
            # The final streamed chunk has a None delta.content, so skip it.
            current_context = chunk.choices[0].delta.content
            if current_context is not None:
                yield current_context
except openai.AuthenticationError as error:
print("401 Authentication Error:", error)
raise Exception(
"Invalid OPENAI_API_KEY. Please re-run with a valid key.")
except Exception as error:
print("Streaming Error:", error)
raise Exception("Internal Server Error")
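

# Illustrative usage of get_remote_chat_response() (comment only): the function
# is a generator, so consume it chunk by chunk to stream the answer to stdout.
#
#   for text in get_remote_chat_response(rag_prompt, model="gpt-4-1106-preview"):
#       print(text, end="", flush=True)
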
def get_other_chat_response(messages, model="local-model"):
"""
Returns a streamed chat response from a local server.
Parameters:
messages (List[dict]): List of messages to be included in the prompt.
model (str): The model to be used for encoding, default is "gpt-4-1106-preview".
Returns:
str: The streamed chat response.
"""
try:
if model == "local-model":
url = "http://localhost:1234/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"messages": messages,
"temperature": 0.2,
"max_tokens": -1,
"stream": True,
}
response = requests.post(
url, headers=headers, data=json.dumps(data), stream=True, timeout=120
)
if response.status_code == 200:
for chunk in response.iter_content(chunk_size=None):
decoded_chunk = chunk.decode()
if (
"data:" in decoded_chunk
and decoded_chunk.split("data:")[1].strip()
): # Check if the chunk is not empty
try:
chunk_dict = json.loads(
decoded_chunk.split("data:")[1].strip()
)
yield chunk_dict["choices"][0]["delta"].get("content", "")
except json.JSONDecodeError:
pass
else:
print(f"Error: {response.status_code}, {response.text}")
raise Exception("Internal Server Error")
else:
if not os.environ.get("OPENROUTER_API_KEY"):
raise Exception(
f"For non-OpenAI models, like {model}, set your OPENROUTER_API_KEY."
)
response = requests.post(
url="https://openrouter.ai/api/v1/chat/completions",
headers={
"Authorization": f"Bearer {os.environ.get('OPENROUTER_API_KEY')}",
"HTTP-Referer": os.environ.get(
"OPENROUTER_APP_URL", "https://fleet.so/context"
),
"X-Title": os.environ.get("OPENROUTER_APP_TITLE", "Fleet Context"),
"Content-Type": "application/json",
},
data=json.dumps(
{"model": model, "messages": messages, "stream": True}),
stream=True,
timeout=120,
)
if response.status_code == 200:
for chunk in response.iter_lines():
decoded_chunk = chunk.decode("utf-8")
if (
"data:" in decoded_chunk
and decoded_chunk.split("data:")[1].strip()
): # Check if the chunk is not empty
try:
chunk_dict = json.loads(
decoded_chunk.split("data:")[1].strip()
)
yield chunk_dict["choices"][0]["delta"].get("content", "")
except json.JSONDecodeError:
pass
else:
print(f"Error: {response.status_code}, {response.text}")
raise Exception("Internal Server Error")
except requests.exceptions.RequestException as error:
print("Request Error:", error)
raise Exception(
"Invalid request. Please check your request parameters.") |