# MAPI_LLM / reaction_prediction.py
# Author: maykcaldas
# Commit message: "Upload 7 files"
# Commit: 77cbf82
import os
import re
from langchain.agents import Tool, tool
# from mp_api.client import MPRester
from pymatgen.ext.matproj import MPRester
from rxn_network.entries.entry_set import GibbsEntrySet
from rxn_network.enumerators.basic import BasicEnumerator
class SynthesisReactions:
    """Suggest a synthesis reaction for a material and phrase it as an instruction.

    Entries for the chemical system spanned by the given formulas are fetched
    from the Materials Project (API key read from the ``MAPI_API_KEY``
    environment variable). They are converted with
    ``GibbsEntrySet.from_computed_entries`` at ``temp`` and filtered with
    ``filter_by_stability(stabl)``, then enumerated with rxn_network's
    ``BasicEnumerator``. The first enumerated reaction is rewritten as
    plain-English text, e.g.
    ``"mix 1 mols Li2O with 1 mols Fe2O3 to yield 2 mols LiFeO2"``.

    Args:
        temp: Temperature passed to ``GibbsEntrySet.from_computed_entries``.
        stabl: Threshold passed to ``GibbsEntrySet.filter_by_stability``.
        exclusive_precursors: Forwarded to ``BasicEnumerator`` in precursor mode.
        exclusive_targets: Forwarded to ``BasicEnumerator`` in target mode.
    """

    # Returned unformatted when the enumerator yields no reactions.
    _NO_REACTIONS = "Error: No reactions found."
    # Splits a reaction string into terms and keeps the "+" / "->" separators
    # (capturing group), so formulas such as "Fe2O3" or "Ca(OH)2" stay intact.
    _SEPARATOR_RE = re.compile(r"\s*(->|\+)\s*")
    # A leading coefficient ("2", "0.5", ".5"), optional whitespace, then the formula.
    _COEFF_RE = re.compile(r"(\d*\.\d+|\d+)\s*(.*)", re.DOTALL)
    # Words that _convert_equation_pieces substitutes for the separators.
    _CONNECTORS = ("with", "to yield", "yields")

    def __init__(self, temp=900, stabl=0.025, exclusive_precursors=False, exclusive_targets=False):
        self.temp = temp
        self.stabl = stabl
        self.exclusive_precursors = exclusive_precursors
        self.exclusive_targets = exclusive_targets

    def _split_string(self, s):
        """Return the chemical system spanned by formula(s) *s*, e.g. ``"Fe-Li-O"``.

        *s* is a formula string or a list of formula strings; a list is
        concatenated first. Symbols are de-duplicated and sorted so the result
        is deterministic.
        """
        if isinstance(s, list):
            s = "".join(s)
        # Capital letter + lowercase letters is an element symbol. As in the
        # original tokenizer, a bare run of lowercase letters is also kept.
        symbols = set(re.findall(r"[a-z]+|[A-Z][a-z]*", s))
        return "-".join(sorted(symbols))

    @staticmethod
    def _parse_formulas(formulas):
        """Split a comma-separated formula string into stripped, non-empty formulas."""
        return [f.strip() for f in formulas.split(",") if f.strip()]

    def _first_reaction(self, formulas, **enumerator_kwargs):
        """Enumerate reactions in the chemsys of *formulas* and return the first one.

        Args:
            formulas: List of formula strings that define the chemical system.
            **enumerator_kwargs: Passed unchanged to ``BasicEnumerator``.

        Returns:
            ``str()`` of the first enumerated reaction, or ``_NO_REACTIONS``
            when the enumerator yields nothing.

        Raises:
            ValueError: if *formulas* is empty.
        """
        if not formulas:
            raise ValueError("No formulas provided.")
        with MPRester(os.getenv("MAPI_API_KEY")) as mpr:
            entries = mpr.get_entries_in_chemsys(self._split_string(formulas))
        gibbs_entries = GibbsEntrySet.from_computed_entries(entries, self.temp)
        filtered_entries = gibbs_entries.filter_by_stability(self.stabl)
        rxns = BasicEnumerator(**enumerator_kwargs).enumerate(filtered_entries)
        try:
            rxn_choice = next(iter(rxns))
        except StopIteration:
            return self._NO_REACTIONS
        return str(rxn_choice)

    def _get_rxn_from_precursor(self, precursors_formulas):
        """Return the first reaction enumerated from comma-separated precursor formulas."""
        precursors = self._parse_formulas(precursors_formulas)
        return self._first_reaction(
            precursors,
            precursors=precursors,
            exclusive_precursors=self.exclusive_precursors,
        )

    def _get_rxn_from_target(self, targets_formulas):
        """Return the first reaction enumerated toward comma-separated target formulas."""
        targets = self._parse_formulas(targets_formulas)
        return self._first_reaction(
            targets,
            targets=targets,
            exclusive_targets=self.exclusive_targets,
        )

    def _break_equation(self, equation):
        """Split ``"A + 2 B -> C"`` into ``["A", "+", "2 B", "->", "C"]``."""
        return [part.strip() for part in self._SEPARATOR_RE.split(equation.strip())]

    def _convert_equation_pieces(self, equation_pieces):
        """Replace separators with words.

        With several reactants, "+" becomes "with" and "->" becomes "to yield".
        Otherwise "->" becomes "yields".
        """
        if "+" in equation_pieces:
            replacements = {"+": "with", "->": "to yield"}
        else:
            replacements = {"->": "yields"}
        return [replacements.get(piece, piece) for piece in equation_pieces]

    def _split_equation_pieces(self, equation_pieces):
        """Split each term into ``[coefficient, formula]``; a missing coefficient becomes ``"1"``."""
        new_pieces = []
        for piece in equation_pieces:
            if piece in self._CONNECTORS:
                new_pieces.append(piece)
                continue
            match = self._COEFF_RE.match(piece)
            if match:
                # Strip the gap between coefficient and formula so the output
                # never contains a double space.
                new_pieces.extend([match.group(1), match.group(2).strip()])
            else:
                new_pieces.extend(["1", piece])
        return new_pieces

    def _modify_mols(self, equation_pieces):
        """Append " mols" to every numeric piece (integer or single-decimal-point)."""
        return [
            f"{piece} mols" if piece.replace(".", "", 1).isdigit() else piece
            for piece in equation_pieces
        ]

    def _combine_equation_pieces(self, equation_pieces):
        """Join pieces with spaces, prefixing "mix" when there are several reactants."""
        if "with" in equation_pieces:
            # Build a new list rather than inserting into the caller's list.
            equation_pieces = ["mix", *equation_pieces]
        return " ".join(equation_pieces)

    def _process_equation(self, equation):
        """Turn a reaction string into a plain-English synthesis instruction."""
        equation_pieces = self._break_equation(equation)
        converted_pieces = self._convert_equation_pieces(equation_pieces)
        split_pieces = self._split_equation_pieces(converted_pieces)
        modified_pieces = self._modify_mols(split_pieces)
        return self._combine_equation_pieces(modified_pieces)

    def get_reaction(self, input_string):
        """Suggest a synthesis reaction from ``"<mode>, <formula>[, <formula>...]"``.

        *mode* is ``precursor`` or ``target`` (case-insensitive), e.g.
        ``"precursor, Li2O, Fe2O3"``.

        Returns:
            The reaction phrased as an instruction, or
            ``"Error: No reactions found."`` when nothing was enumerated.

        Raises:
            ValueError: on a missing comma, an unknown mode or an empty formula list.
        """
        input_parts = input_string.split(",", 1)
        if len(input_parts) != 2:
            raise ValueError("Invalid input format. Expected 'precursor' or 'target', followed by a comma, and then the list of formulas separated by a comma.")
        mode, formulas = input_parts
        mode = mode.lower().strip()
        if mode == "precursor":
            reaction = self._get_rxn_from_precursor(formulas)
        elif mode == "target":
            reaction = self._get_rxn_from_target(formulas)
        else:
            raise ValueError("Invalid mode. Expected 'precursor' or 'target'.")
        # Return the error message as-is. Sending it through _process_equation
        # would garble it into "1 mols Error No reactions found".
        if reaction == self._NO_REACTIONS:
            return reaction
        return self._process_equation(reaction)

    def get_tools(self):
        """Return this helper wrapped as a single langchain ``Tool``."""
        return [
            Tool(
                name="Get a synthesis reaction for a material",
                func=self.get_reaction,
                description=(
                    "This function is useful for suggesting a synthesis reaction for a material. "
                    "Give this tool a string containing either precursor or target, then a comma, followed by the formulas separated by comma as input and returns a synthesis reaction. "
                    "The mode is used to determine if the input is a precursor or a target material. "
                ),
            )
        ]