# RxnIM / rxnim.py
# CYF200127's picture
# Update rxnim.py
# 6b368e8 verified
import base64
import json
from openai import AzureOpenAI
import os
import sys
sys.path.append('./rxn/')
import torch
import json
from getReaction import get_reaction
class RXNIM:
    """Run an Azure OpenAI chat model over a reaction image with tool support.

    The model is offered a single function tool, ``get_reaction``, whose schema
    describes it as returning a list of reactions (reactants, conditions and
    products) for an image path. ``process`` loops over model turns, executing
    any requested tool calls, until the model answers without calling a tool.
    """

    def __init__(self, api_version='2024-06-01', azure_endpoint='https://hkust.azure-api.net'):
        """Create the Azure OpenAI client and register the available tools.

        Args:
            api_version: Azure OpenAI API version passed to the client.
            azure_endpoint: Base URL of the Azure OpenAI endpoint.

        Raises:
            ValueError: If neither the ``key`` nor the ``KEY`` environment
                variable holds an API key.
        """
        # Read the API key from the environment. The lowercase name is the one
        # historically used; the uppercase spelling is accepted as well so the
        # lookup and the error message can no longer disagree on
        # case-sensitive platforms.
        self.API_KEY = os.environ.get('key') or os.environ.get('KEY')
        if not self.API_KEY:
            raise ValueError("Environment variable 'key' (or 'KEY') not set.")

        # Set up client
        self.client = AzureOpenAI(
            api_key=self.API_KEY,
            api_version=api_version,
            azure_endpoint=azure_endpoint,
        )

        # Tool schemas advertised to the model on every request.
        self.tools = [
            {
                'type': 'function',
                'function': {
                    'name': 'get_reaction',
                    'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'image_path': {
                                'type': 'string',
                                'description': 'The path to the reaction image.',
                            },
                        },
                        'required': ['image_path'],
                        'additionalProperties': False,
                    },
                },
            },
        ]

        # Maps a tool name requested by the model to the callable that runs it.
        self.TOOL_MAP = {
            'get_reaction': get_reaction,
        }

    def encode_image(self, image_path: str):
        '''Returns a base64 string of the input image.'''
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    @staticmethod
    def _guess_image_mime_type(image_path: str) -> str:
        """Return the image MIME type implied by the file extension.

        Falls back to ``image/png`` (the value previously hard-coded) when the
        extension is unknown or does not map to an image type.
        """
        import mimetypes  # local import: keeps this block self-contained

        mime_type, _ = mimetypes.guess_type(image_path)
        if mime_type and mime_type.startswith('image/'):
            return mime_type
        return 'image/png'

    def _run_tool_call(self, tool_call, image_path: str) -> dict:
        """Execute one requested tool call and build its ``tool`` reply message.

        The tool is always invoked with ``image_path`` as given to ``process``;
        the arguments supplied by the model are not used. Any exception raised
        by the tool, and any unknown tool name, is reported back to the model
        as an ``{'error': ...}`` payload instead of aborting the conversation.
        """
        tool_name = tool_call.function.name
        tool = self.TOOL_MAP.get(tool_name)

        if tool is None:
            tool_result = {'error': f"Unknown tool called: {tool_name}"}
        else:
            try:
                # Call the tool function
                tool_result = tool(image_path)
            except Exception as e:
                # Best effort: let the model see the failure and react to it.
                tool_result = {'error': str(e)}
            else:
                print(f'{tool_name} result: {tool_result}')

        return {
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                f'{tool_name}': tool_result,
            }),
            'tool_call_id': tool_call.id,
        }

    def process(self, image_path: str, prompt_path: str, max_iterations: int = 5):
        """Ask the model about an image, running tool calls until it answers.

        Args:
            image_path: Path to the image; it is sent to the model as a base64
                data URL and also passed to any tool the model calls.
            prompt_path: Path to a UTF-8 text file holding the user prompt.
            max_iterations: Maximum number of model round-trips before giving
                up. Defaults to 5.

        Returns:
            The content of the first assistant message that contains no tool
            calls, or a fixed explanatory string when ``max_iterations``
            round-trips all ended in tool calls.
        """
        # Encode image
        base64_image = self.encode_image(image_path)
        mime_type = self._guess_image_mime_type(image_path)

        # Read prompt
        with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
            prompt = prompt_file.read()

        # Build initial messages
        messages = [
            {'role': 'system', 'content': 'You are a helpful assistant. Before providing the final answer, consider if any additional information or tool usage is needed to improve your response.'},
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': f'data:{mime_type};base64,{base64_image}'
                        }
                    }
                ]
            },
        ]

        for iteration in range(1, max_iterations + 1):
            print(f'Iteration {iteration}')

            # Call the model
            response = self.client.chat.completions.create(
                model='gpt-4o',
                temperature=0,
                response_format={'type': 'json_object'},
                messages=messages,
                tools=self.tools,
            )

            # Keep the assistant turn in the history so that the tool replies
            # appended below refer to tool calls the API has already seen.
            assistant_message = response.choices[0].message
            messages.append(assistant_message)

            tool_calls = getattr(assistant_message, 'tool_calls', None)
            if not tool_calls:
                # No more tool calls, assume task is completed
                return assistant_message.content

            results = []
            for tool_call in tool_calls:
                results.append(self._run_tool_call(tool_call, image_path))
                print(results)

            # Add tool results to messages
            messages.extend(results)

        # Exceeded maximum iterations
        return "The assistant could not complete the task within the maximum number of iterations."