|
import streamlit as st |
|
import pandas as pd |
|
import os |
|
from crewai import Crew |
|
from langchain_groq import ChatGroq |
|
import streamlit_ace as st_ace |
|
import traceback |
|
import contextlib |
|
import io |
|
from crewai_tools import FileReadTool |
|
import matplotlib.pyplot as plt |
|
import glob |
|
from dotenv import load_dotenv |
|
from autotabml_agents import initialize_agents |
|
from autotabml_tasks import create_tasks |
|
|
|
|
|
# Scratch directory where uploaded datasets are written (save_uploaded_file)
# and whose path is handed to the agents in main.
TEMP_DIR = "temp_dir"

# Directory whose files are listed for download in the app's sidebar.
OUTPUT_DIR = "Output_dir"

# exist_ok=True avoids the check-then-create race of
# `if not os.path.exists(...): os.makedirs(...)`: this module-level code runs
# on every script execution, and a concurrent creation between the check and
# the makedirs call would otherwise raise FileExistsError.
os.makedirs(TEMP_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
def save_uploaded_file(uploaded_file, directory=None):
    """Write an uploaded file's bytes into *directory* and return the saved path.

    Args:
        uploaded_file: Upload object exposing ``name`` and ``getbuffer()``,
            as returned by the Streamlit file uploader used in main.
        directory: Destination folder. Defaults to TEMP_DIR when omitted.

    Returns:
        The path the file was written to.

    Raises:
        ValueError: If the upload's name has no usable file-name component.
    """
    if directory is None:
        directory = TEMP_DIR
    # The name is supplied by the client: keep only its final component so a
    # crafted name such as "../../x" or "/etc/x" cannot write outside *directory*.
    safe_name = os.path.basename(uploaded_file.name)
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid uploaded file name: {uploaded_file.name!r}")
    file_path = os.path.join(directory, safe_name)
    with open(file_path, 'wb') as f:
        f.write(uploaded_file.getbuffer())
    return file_path
|
|
|
|
|
# Populate os.environ from a .env file (if one is present) before the key is read.
load_dotenv()

# Groq credential passed to ChatGroq in initialize_llm; None when unset.
groq_api_key = os.getenv("GROQ_API_KEY")
|
|
|
|
|
def _extract_code(text):
    """Return the Python source embedded in an LLM reply.

    Prefers the body of the first Markdown code fence. Any info string on the
    opening fence line (such as ``python``) is dropped so it is not executed
    as code, and a fence that is never closed runs to the end of *text*.
    Replies without a fenced block fall back to stripping backticks from both
    ends and cutting at the first remaining ``` marker.
    """
    import re  # local import: only this helper needs regular expressions

    fenced = re.search(r"```[^\n`]*\n(.*?)(?:```|\Z)", text, re.DOTALL)
    if fenced:
        return fenced.group(1)
    code = text.strip("`")
    cut = code.find("```")
    return code if cut == -1 else code[:cut]


def _run_crew(agents, tasks, spinner_text):
    """Run all *agents* on *tasks* under a spinner and return the code produced.

    Returns None when the crew's result is empty, so callers keep the code
    they already have.
    """
    with st.spinner(spinner_text):
        crew = Crew(
            agents=list(agents.values()),
            tasks=tasks,
            verbose=2,
        )
        result = crew.kickoff()
    if not result:
        return None
    # NOTE(review): kickoff() is treated as returning text (or an object whose
    # str() is the raw output); str() is a no-op for plain strings.
    return _extract_code(str(result))


def _code_editor(code):
    """Render the Ace Python editor pre-filled with *code*; return its contents."""
    return st_ace.st_ace(
        value=code,
        language='python',
        theme='monokai',
        keybinding='vscode',
        min_lines=20,
        max_lines=50,
    )


def _execute_code(code, dataset):
    """Execute *code* with the name ``dataset`` bound to *dataset*.

    Returns:
        (text, success): everything the code printed to stdout followed, if
        it raised, by the formatted traceback; and whether it ran without an
        exception.
    """
    # SECURITY: this runs LLM-generated / user-edited code with the full
    # privileges of the app process. Only deploy where that is acceptable.
    # Executing in a *copy* of the module globals keeps the app's imports
    # visible to the script while preventing it from rebinding the app's own
    # names (e.g. OUTPUT_DIR, os, main) for the rest of this run.
    namespace = dict(globals())
    namespace['dataset'] = dataset
    output = io.StringIO()
    error_text = ""
    with contextlib.redirect_stdout(output):
        try:
            exec(code, namespace)
        except Exception:
            # The full traceback (with line numbers) is what the user is
            # asked to paste into the Debug box.
            error_text = traceback.format_exc()
    return output.getvalue() + error_text, not error_text


def main():
    """Render the AutoTabML Streamlit app.

    The user picks a Groq model, describes the ML problem and uploads a CSV.
    "Process" asks the CrewAI agents for code, which is shown in an editable
    Ace editor and can then be refined ("Modify"), repaired from a pasted
    error ("Debug") and executed in-process ("Run"). Files found in
    OUTPUT_DIR are offered for download in the sidebar.
    """
    set_custom_css()

    # Defaults that persist across Streamlit reruns.
    if 'edited_code' not in st.session_state:
        st.session_state['edited_code'] = ""
    if 'code_generated' not in st.session_state:
        st.session_state['code_generated'] = False

    st.markdown("""
        <div class="header">
        <h1>AutoTabML</h1>
        <p>Automated Machine Learning Code Generation for Tabular Data</p>
        </div>
        """, unsafe_allow_html=True)

    st.sidebar.title('LLM Model')
    model = st.sidebar.selectbox(
        'Model',
        ["llama3-70b-8192"]
    )
    llm = initialize_llm(model)

    user_question = st.text_area("Describe your ML problem:", key="user_question")
    uploaded_file = st.file_uploader("Upload a sample .csv of your data", key="uploaded_file")
    # The agents are told the dataset's file name; use a placeholder until a
    # file has been uploaded.
    file_name = uploaded_file.name if uploaded_file is not None else "dataset.csv"

    agents = initialize_agents(llm, file_name, TEMP_DIR)

    # Initialised up front so a failed read cannot leave `df` unbound for the
    # create_tasks calls below.
    df = None
    data_upload = False
    if uploaded_file is not None:
        try:
            save_uploaded_file(uploaded_file)
            # Defensive: make sure read_csv starts from the first byte.
            uploaded_file.seek(0)
            df = pd.read_csv(uploaded_file)
        except Exception as e:
            df = None
            st.error(f"Error reading the file: {e}")
        else:
            st.write("Data successfully uploaded:")
            st.dataframe(df.head())
            data_upload = True

    if st.button('Process'):
        tasks = create_tasks(
            "Process", user_question, file_name, data_upload, df,
            None, st.session_state['edited_code'], None, agents,
        )
        code = _run_crew(agents, tasks, 'Processing...')
        if code is not None:
            st.session_state['edited_code'] = code
            st.session_state['code_generated'] = True

    st.session_state['edited_code'] = _code_editor(st.session_state['edited_code'])

    if st.session_state['code_generated']:

        suggestion = st.text_area("Suggest modifications to the generated code (optional):", key="suggestion")
        if st.button('Modify'):
            if st.session_state['edited_code'] and suggestion:
                tasks = create_tasks(
                    "Modify", user_question, file_name, data_upload, df,
                    suggestion, st.session_state['edited_code'], None, agents,
                )
                code = _run_crew(agents, tasks, 'Modifying code...')
                if code is not None:
                    st.session_state['edited_code'] = code
                    st.write("Modified code:")
                    st.session_state['edited_code'] = _code_editor(st.session_state['edited_code'])

        debugger = st.text_area("Paste error message here for debugging (optional):", key="debugger")
        if st.button('Debug'):
            if st.session_state['edited_code'] and debugger:
                tasks = create_tasks(
                    "Debug", user_question, file_name, data_upload, df,
                    None, st.session_state['edited_code'], debugger, agents,
                )
                code = _run_crew(agents, tasks, 'Debugging code...')
                if code is not None:
                    st.session_state['edited_code'] = code
                    st.write("Debugged code:")
                    st.session_state['edited_code'] = _code_editor(st.session_state['edited_code'])

        if st.button('Run'):
            final_code = st.session_state["edited_code"]
            with st.expander("Final Code"):
                st.code(final_code, language='python')

            result, success = _execute_code(final_code, df)

            st.subheader('Output:')
            st.text(result)

            # Show every figure the script left open, then close them so they
            # do not pile up and reappear on later runs.
            figs = [plt.figure(num) for num in plt.get_fignums()]
            if figs:
                st.subheader('Generated Plots:')
                for fig in figs:
                    st.pyplot(fig)
            plt.close('all')

            if success:
                st.success("Code executed successfully!")
            else:
                st.error("Code execution failed! Waiting for debugging input...")

    with st.sidebar:
        st.header('Output_dir :')
        # Sorted so the download list has a stable order between reruns.
        for path in sorted(glob.glob(os.path.join(OUTPUT_DIR, '*'))):
            if os.path.isfile(path):
                with open(path, 'rb') as f:
                    st.download_button(label=f'Download {os.path.basename(path)}', data=f, file_name=os.path.basename(path))
|
|
|
|
|
|
|
|
|
def set_custom_css():
    """Inject the app's stylesheet into the page.

    Styles the page body, the ``.header`` banner rendered by main, Streamlit
    buttons (including their hover state) and a ``.spinner`` class.
    """
    stylesheet = """
    <style>
    body {
        background: #0e0e0e;
        color: #e0e0e0;
        font-family: 'Roboto', sans-serif;
    }
    .header {
        background: linear-gradient(135deg, #6e3aff, #b839ff);
        padding: 10px;
        border-radius: 10px;
    }
    .header h1, .header p {
        color: white;
        text-align: center;
    }
    .stButton button {
        background-color: #b839ff;
        color: white;
        border-radius: 10px;
        font-size: 16px;
        padding: 10px 20px;
    }
    .stButton button:hover {
        background-color: #6e3aff;
        color: #e0e0e0;
    }
    .spinner {
        display: flex;
        justify-content: center;
        align-items: center;
    }
    </style>
    """
    st.markdown(stylesheet, unsafe_allow_html=True)
|
|
|
|
|
def initialize_llm(model):
    """Build the Groq chat model that main hands to the agents.

    Args:
        model: Groq model identifier selected in the sidebar.

    Returns:
        A ChatGroq client configured with temperature 0 and the module-level
        ``groq_api_key``.
    """
    client_options = {
        "temperature": 0,
        "groq_api_key": groq_api_key,
        "model_name": model,
    }
    return ChatGroq(**client_options)
|
|
|
|
|
if __name__ == "__main__":
    # Render the app only when this module is executed as a script, not when imported.
    main()