# MAPI_LLM / agent.py
# Last update by maykcaldas: "Update agent.py" (commit 231d7e5)
from mapi_tools import MAPI_class_tools, MAPI_reg_tools
from utils import common_tools
from reaction_prediction import SynthesisReactions
from langchain import OpenAI
from gpt_index import GPTListIndex, GPTIndexMemory
from langchain import agents
from langchain.agents import initialize_agent
# --- Module-level tool providers --------------------------------------------
# Each object below exposes get_tools(); Agent concatenates their tool lists.
# NOTE(review): argument roles are inferred from the call sites — presumably
# (API field, human-readable label[, positive label, negative label]).
# Confirm against mapi_tools.

# Binary (classification) properties.
stability = MAPI_class_tools(
    "is_stable", "stable", "Stable", "Unstable"
)
magnetism = MAPI_class_tools(
    "is_magnetic", "magnetic", "Magnetic", "Not magnetic"
)
metal = MAPI_class_tools(
    "is_metal", "metallic", "Metal", "Not metal"
)
gap_direct = MAPI_class_tools(
    "is_gap_direct", "gap direct", "Gap direct", "Gap indirect"
)

# Numeric (regression) properties. The labels are descriptive text only, so
# they must not carry the " gap" suffix that belongs to band_gap alone.
band_gap = MAPI_reg_tools(
    "band_gap", "band gap"
)
energy_per_atom = MAPI_reg_tools(
    "energy_per_atom", "energy per atom"
)
formation_energy_per_atom = MAPI_reg_tools(
    "formation_energy_per_atom", "formation energy per atom"
)
volume = MAPI_reg_tools(
    "volume", "volume"
)
density = MAPI_reg_tools(
    "density", "density"
)
atomic_density = MAPI_reg_tools(
    "density_atomic", "atomic density"
)
# NOTE(review): in the Materials Project summary schema, e_total / e_ionic /
# e_electronic appear to be dielectric-constant contributions rather than
# energies, and the ionic field may be spelled "e_ionic" rather than "e_ion".
# Confirm the field names and labels against mapi_tools and the MP API.
electronic_energy = MAPI_reg_tools(
    "e_electronic", "electronic energy"
)
ionic_energy = MAPI_reg_tools(
    "e_ion", "ionic energy"
)
total_energy = MAPI_reg_tools(
    "e_total", "total energy"
)

# Synthesis-reaction prediction tools.
reaction = SynthesisReactions()
class Agent:
    """LangChain agent that answers queries using the module-level tools.

    The agent is built with ``agent="zero-shot-react-description"``. It has
    access to these tools:

    * the Materials Project property tools defined in this module;
    * the synthesis-reaction tools;
    * LangChain's ``llm-math`` and ``python_repl`` tools;
    * ``utils.common_tools``.
    """

    def __init__(self, openai_api_key, mapi_api_key, *, temperature=0.7):
        """Build the LLM, the tool list and the agent chain.

        Args:
            openai_api_key: OpenAI key forwarded to the LangChain ``OpenAI``
                wrapper. When it is falsy, LangChain falls back to the
                ``OPENAI_API_KEY`` environment variable (per its
                ``validate_environment``). This preserves the previous
                env-only behaviour.
            mapi_api_key: Materials Project key. Currently unused here (see
                the note below).
            temperature: Sampling temperature for the LLM. The default
                matches the previously hard-coded value.
        """
        # NOTE(review): mapi_api_key is not used in this method. The MAPI tools
        # are instantiated at import time and presumably read their key from
        # the environment. Confirm this, and wire the key through if needed.
        memory = GPTIndexMemory(
            index=GPTListIndex([]),
            memory_key="chat_history",
            query_kwargs={"response_mode": "compact"},
        )
        # Pass the key explicitly; previously it was accepted but ignored.
        llm = OpenAI(temperature=temperature, openai_api_key=openai_api_key)

        # Order matters: it is the order in which tools are listed to the LLM.
        providers = (
            stability,
            magnetism,
            gap_direct,
            metal,
            band_gap,
            volume,
            density,
            atomic_density,
            formation_energy_per_atom,
            energy_per_atom,
            electronic_energy,
            ionic_energy,
            total_energy,
            reaction,
        )
        tools = [tool for provider in providers for tool in provider.get_tools()]
        # NOTE(review): "python_repl" lets the model execute arbitrary Python in
        # this process. Do not expose this agent to untrusted input without
        # sandboxing.
        tools.extend(agents.load_tools(["llm-math", "python_repl"], llm=llm))
        tools.extend(common_tools)

        # NOTE(review): the zero-shot-react-description prompt may not have a
        # chat_history slot. If so, this memory is recorded but never shown to
        # the LLM. Confirm this, or consider "conversational-react-description".
        self.agent_chain = initialize_agent(
            tools, llm, agent="zero-shot-react-description", verbose=True, memory=memory
        )

    def run(self, query):
        """Run the agent on ``query`` and return the result of ``agent_chain.run``."""
        return self.agent_chain.run(input=query)