import base64
import json
import os
import sys

from openai import AzureOpenAI

# Make the local rxn/ package importable before pulling in its modules.
sys.path.append('./rxn/')

import torch
from getReaction import get_reaction


class RXNIM:
    '''Extracts reaction data from images via an Azure OpenAI chat model with tool calling.'''

    def __init__(self, api_version='2024-06-01', azure_endpoint='https://hkust.azure-api.net'):
        # The Azure OpenAI API key is read from the 'key' environment variable.
        self.API_KEY = os.environ.get('key')
        if not self.API_KEY:
            raise ValueError("Environment variable 'key' not set.")

        self.client = AzureOpenAI(
            api_key=self.API_KEY,
            api_version=api_version,
            azure_endpoint=azure_endpoint,
        )

        # Tool schema advertised to the model (OpenAI function-calling format).
        self.tools = [
            {
                'type': 'function',
                'function': {
                    'name': 'get_reaction',
                    'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'image_path': {
                                'type': 'string',
                                'description': 'The path to the reaction image.',
                            },
                        },
                        'required': ['image_path'],
                        'additionalProperties': False,
                    },
                },
            },
        ]

        # Mapping from tool name to the Python callable that implements it.
        self.TOOL_MAP = {
            'get_reaction': get_reaction,
        }

    def encode_image(self, image_path: str):
        '''Returns a base64 string of the input image.'''
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def process(self, image_path: str, prompt_path: str):
        '''Runs the tool-calling loop on a reaction image and returns the model's final answer.'''
        base64_image = self.encode_image(image_path)

        with open(prompt_path, 'r') as prompt_file:
            prompt = prompt_file.read()

        messages = [
            {'role': 'system', 'content': 'You are a helpful assistant. Before providing the final answer, consider if any additional information or tool usage is needed to improve your response.'},
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt,
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': f'data:image/png;base64,{base64_image}',
                        },
                    },
                ],
            },
        ]

        MAX_ITERATIONS = 5
        iterations = 0

        while iterations < MAX_ITERATIONS:
            iterations += 1
            print(f'Iteration {iterations}')

            response = self.client.chat.completions.create(
                model='gpt-4o',
                temperature=0,
                response_format={'type': 'json_object'},
                messages=messages,
                tools=self.tools,
            )

            assistant_message = response.choices[0].message
            messages.append(assistant_message)

            if hasattr(assistant_message, 'tool_calls') and assistant_message.tool_calls:
                tool_calls = assistant_message.tool_calls
                results = []

                for tool_call in tool_calls:
                    tool_name = tool_call.function.name
                    tool_arguments = tool_call.function.arguments
                    tool_call_id = tool_call.id

                    # Parse the model-supplied arguments; the tool is still invoked with the
                    # locally known image_path so a hallucinated path cannot break the call.
                    tool_args = json.loads(tool_arguments)

                    if tool_name in self.TOOL_MAP:
                        try:
                            tool_result = self.TOOL_MAP[tool_name](image_path)
                            print(f'{tool_name} result: {tool_result}')
                        except Exception as e:
                            tool_result = {'error': str(e)}
                    else:
                        tool_result = {'error': f"Unknown tool called: {tool_name}"}

                    results.append({
                        'role': 'tool',
                        'content': json.dumps({
                            'image_path': image_path,
                            f'{tool_name}': tool_result,
                        }),
                        'tool_call_id': tool_call_id,
                    })
                print(results)

                # Feed the tool outputs back to the model for the next iteration.
                messages.extend(results)
            else:
                # No tool calls: the assistant has produced its final answer.
                break
        else:
            # The while loop exhausted MAX_ITERATIONS without a final answer.
            return "The assistant could not complete the task within the maximum number of iterations."

        final_content = assistant_message.content
        return final_content
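
if __name__ == '__main__':
    # Minimal usage sketch. The image and prompt paths below are placeholders
    # (not files shipped with this module); the 'key' environment variable must
    # hold a valid Azure OpenAI API key for the configured endpoint.
    rxnim = RXNIM()
    result = rxnim.process('example_reaction.png', 'prompts/reaction_prompt.txt')
    print(result)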