Spaces:
Sleeping
Sleeping
lukestanley
commited on
Commit
•
64c61f3
1
Parent(s):
41ac6cc
Add Anthropic Opus support
Browse files
utils.py
CHANGED
@@ -24,9 +24,9 @@ from huggingface_hub import hf_hub_download
|
|
24 |
|
25 |
URL = "http://localhost:5834/v1/chat/completions"
|
26 |
in_memory_llm = None
|
27 |
-
worker_options = ["runpod", "http", "in_memory", "mistral"]
|
28 |
|
29 |
-
LLM_WORKER = env.get("LLM_WORKER", "
|
30 |
if LLM_WORKER not in worker_options:
|
31 |
raise ValueError(f"Invalid worker: {LLM_WORKER}")
|
32 |
N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1)) # Default to -1, use all layers if available
|
@@ -250,11 +250,62 @@ def llm_stream_mistral_api(prompt: str, pydantic_model_class=None, attempts=0) -
|
|
250 |
return json.loads(output)
|
251 |
|
252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
def query_ai_prompt(prompt, replacements, model_class):
|
255 |
prompt = replace_text(prompt, replacements)
|
256 |
-
if LLM_WORKER == "
|
257 |
-
result =
|
258 |
if LLM_WORKER == "mistral":
|
259 |
result = llm_stream_mistral_api(prompt, model_class)
|
260 |
if LLM_WORKER == "runpod":
|
|
|
24 |
|
25 |
URL = "http://localhost:5834/v1/chat/completions"
|
26 |
in_memory_llm = None
|
27 |
+
worker_options = ["runpod", "http", "in_memory", "mistral", "anthropic"]
|
28 |
|
29 |
+
LLM_WORKER = env.get("LLM_WORKER", "anthropic")
|
30 |
if LLM_WORKER not in worker_options:
|
31 |
raise ValueError(f"Invalid worker: {LLM_WORKER}")
|
32 |
N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1)) # Default to -1, use all layers if available
|
|
|
250 |
return json.loads(output)
|
251 |
|
252 |
|
253 |
+
def send_anthropic_request(prompt: str):
|
254 |
+
api_key = env.get("ANTHROPIC_API_KEY")
|
255 |
+
if not api_key:
|
256 |
+
print("API key not found. Please set the ANTHROPIC_API_KEY environment variable.")
|
257 |
+
return
|
258 |
+
|
259 |
+
headers = {
|
260 |
+
'x-api-key': api_key,
|
261 |
+
'anthropic-version': '2023-06-01',
|
262 |
+
'Content-Type': 'application/json',
|
263 |
+
}
|
264 |
+
|
265 |
+
data = {
|
266 |
+
"model": "claude-3-opus-20240229",
|
267 |
+
"max_tokens": 1024,
|
268 |
+
"messages": [{"role": "user", "content": prompt}]
|
269 |
+
}
|
270 |
+
|
271 |
+
response = requests.post('https://api.anthropic.com/v1/messages', headers=headers, data=json.dumps(data))
|
272 |
+
if response.status_code != 200:
|
273 |
+
print(f"Unexpected Anthropic API status code: {response.status_code} with body: {response.text}")
|
274 |
+
raise ValueError(f"Unexpected Anthropic API status code: {response.status_code} with body: {response.text}")
|
275 |
+
j = response.json()
|
276 |
+
|
277 |
+
text = j['content'][0]["text"]
|
278 |
+
print(text)
|
279 |
+
return text
|
280 |
+
|
281 |
+
def llm_anthropic_api(prompt: str, pydantic_model_class=None, attempts=0) -> Union[str, Dict[str, Any]]:
|
282 |
+
# With no streaming or rate limits, we use the Anthropic API, we have string input and output from send_anthropic_request,
|
283 |
+
# but we need to convert it to JSON for the pydantic model class like the other APIs.
|
284 |
+
output = send_anthropic_request(prompt)
|
285 |
+
if pydantic_model_class:
|
286 |
+
try:
|
287 |
+
parsed_result = pydantic_model_class.model_validate_json(output)
|
288 |
+
print(parsed_result)
|
289 |
+
# This will raise an exception if the model is invalid.
|
290 |
+
return json.loads(output)
|
291 |
+
except Exception as e:
|
292 |
+
print(f"Error validating pydantic model: {e}")
|
293 |
+
# Let's retry by calling ourselves again if attempts < 3
|
294 |
+
if attempts == 0:
|
295 |
+
# We modify the prompt to remind it to output JSON in the required format
|
296 |
+
prompt = f"{prompt} You must output the JSON in the required format only, with no remarks or prefacing remarks - JUST JSON!"
|
297 |
+
if attempts < 3:
|
298 |
+
attempts += 1
|
299 |
+
print(f"Retrying Anthropic API call, attempt {attempts}")
|
300 |
+
return llm_anthropic_api(prompt, pydantic_model_class, attempts)
|
301 |
+
else:
|
302 |
+
print("No pydantic model class provided, returning without class validation")
|
303 |
+
return json.loads(output)
|
304 |
|
305 |
def query_ai_prompt(prompt, replacements, model_class):
|
306 |
prompt = replace_text(prompt, replacements)
|
307 |
+
if LLM_WORKER == "anthropic":
|
308 |
+
result = llm_anthropic_api(prompt, model_class)
|
309 |
if LLM_WORKER == "mistral":
|
310 |
result = llm_stream_mistral_api(prompt, model_class)
|
311 |
if LLM_WORKER == "runpod":
|