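# NOTE: the lines below are file-viewer page residue, kept as comments so the module parses.
# GenSim / gensim /agent.py
# LeroyWaa's picture
# add gensim code
# 8fc2b4e
# raw
# history blame contribute delete
# No virus
# 7.88 kB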
import numpy as np
import os
import IPython
import random
import json
import traceback
import pybullet as p
from gensim.utils import (
save_text,
add_to_txt,
extract_code,
extract_dict,
extract_list,
extract_assets,
format_dict_prompt,
sample_list_reference,
generate_feedback,
)
class Agent:
    """Design new simulation tasks and generate their code through LLM prompts.

    Prompt templates are read from ``prompts/<cfg['prompt_folder']>``, filled
    with entries from the ``memory`` buffers, sent through
    ``generate_feedback`` and the replies are parsed with the ``extract_*``
    helpers imported from ``gensim.utils``.
    """

    # Upper bound on how many past tasks are listed in the task prompt.
    MAX_PAST_TASKS = 20

    def __init__(self, cfg, memory):
        """Store the configuration and the shared memory object.

        Args:
            cfg: Mapping providing at least ``model_output_dir``,
                ``prompt_folder`` and ``use_template``.
            memory: Object exposing ``chat_log`` and the ``online_*_buffer``
                attributes read by the other methods.
        """
        self.cfg = cfg
        self.model_output_dir = cfg["model_output_dir"]
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.memory = memory
        self.chat_log = memory.chat_log
        self.use_template = cfg['use_template']

    def _read_prompt(self, filename):
        """Return the text of ``filename`` inside the prompt folder."""
        # A context manager closes the handle instead of leaking it.
        with open(f"{self.prompt_folder}/{filename}") as prompt_file:
            return prompt_file.read()

    @staticmethod
    def _run_generated_code(source, var_name):
        """Execute ``source`` and return the value it binds to ``var_name``.

        The snippet runs in a throwaway copy of the module globals, so it can
        still read module-level names but cannot leave results behind; a value
        produced by an earlier call can therefore never be picked up by
        mistake.

        Raises:
            NameError: If ``source`` does not define ``var_name``.
        """
        # SECURITY NOTE: ``source`` is text produced by the language model and
        # exec() runs it with the full privileges of this process. Only use
        # this with a trusted model endpoint / sandboxed environment.
        namespace = dict(globals())
        namespace.pop(var_name, None)
        exec(source, namespace)
        if var_name not in namespace:
            raise NameError(f"generated code did not define '{var_name}'")
        return namespace[var_name]

    def propose_task(self, proposed_task_names):
        """Ask the model for a new task and store it as ``self.new_task``.

        Args:
            proposed_task_names: Not used by this method; kept so existing
                callers keep working.

        Returns:
            The object the generated snippet binds to ``new_task``. If the
            snippet cannot be executed, a placeholder dict whose
            ``"task-name"`` is ``"dummy"`` is stored and returned instead.
        """
        add_to_txt(self.chat_log, "================= Task and Asset Design!", with_print=True)

        # Both the template and non-template paths start from the same file.
        task_prompt_text = self._read_prompt("cliport_prompt_task.txt")

        if self.use_template:
            task_asset_replacement_str = format_dict_prompt(
                self.memory.online_asset_buffer, self.cfg['task_asset_candidate_num'])
            task_prompt_text = task_prompt_text.replace("TASK_ASSET_PROMPT", task_asset_replacement_str)

            task_desc_replacement_str = format_dict_prompt(
                self.memory.online_task_buffer, self.cfg['task_description_candidate_num'])
            print("prompt task description candidates:")
            print(task_desc_replacement_str)
            task_prompt_text = task_prompt_text.replace("TASK_DESCRIPTION_PROMPT", task_desc_replacement_str)

            if self.cfg['target_task_name']:
                task_prompt_text = task_prompt_text.replace("TARGET_TASK_NAME", self.cfg['target_task_name'])

        # Cap the number of past tasks that go into the prompt.
        print("online_task_buffer size:", len(self.memory.online_task_buffer))
        total_tasks = self.memory.online_task_buffer
        if len(total_tasks) > self.MAX_PAST_TASKS:
            # random.sample needs a sequence; passing a dict view raises
            # TypeError on Python 3.11+.
            total_tasks = dict(random.sample(list(total_tasks.items()), self.MAX_PAST_TASKS))

        task_prompt_text = task_prompt_text.replace("PAST_TASKNAME_TEMPLATE", format_dict_prompt(total_tasks))

        res = generate_feedback(
            task_prompt_text,
            temperature=self.cfg["gpt_temperature"],
            interaction_txt=self.chat_log,
        )

        # Extract dictionary for task name, descriptions, and assets
        task_def = extract_dict(res, prefix="new_task")
        try:
            new_task = self._run_generated_code(task_def, "new_task")
        except Exception:
            # Best effort: fall back to a placeholder task and show why.
            print(str(traceback.format_exc()))
            new_task = {"task-name": "dummy", "assets-used": [], "task_descriptions": ""}

        self.new_task = new_task
        return new_task

    def propose_assets(self):
        """Asset Generation. Not used for now.

        Returns:
            The result of ``extract_assets`` on the model reply when the asset
            prompt template exists, otherwise an empty dict.
        """
        template_name = "cliport_prompt_asset_template.txt"
        if not os.path.exists(f"{self.prompt_folder}/{template_name}"):
            return {}

        add_to_txt(self.chat_log, "================= Asset Generation!", with_print=True)
        asset_prompt_text = self._read_prompt(template_name)

        if self.use_template:
            asset_prompt_text = asset_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
            asset_prompt_text = asset_prompt_text.replace("ASSET_STRING_TEMPLATE", str(self.new_task["assets-used"]))
            print("Template Asset PROMPT: ", asset_prompt_text)

        res = generate_feedback(asset_prompt_text, temperature=0, interaction_txt=self.chat_log)

        # The original referenced an undefined ``task_name`` here; the name of
        # the current task is what the save call below uses as well.
        task_name = self.new_task["task-name"]
        print("Save asset to:", self.model_output_dir, task_name + "_asset_output")
        save_text(self.model_output_dir, f'{task_name}_asset_output', res)
        asset_list = extract_assets(res)
        # save_urdf(asset_list)
        return asset_list

    def api_review(self):
        """Send the API-review prompt for the current task, if its template exists.

        The model reply is not used by this method; nothing is returned.
        """
        template_name = "cliport_prompt_api_template.txt"
        if not os.path.exists(f"{self.prompt_folder}/{template_name}"):
            return

        add_to_txt(self.chat_log, "================= API Preview!", with_print=True)
        api_prompt_text = self._read_prompt(template_name)
        if "task-name" in self.new_task:
            api_prompt_text = api_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
        api_prompt_text = api_prompt_text.replace("TASK_STRING_TEMPLATE", str(self.new_task))

        generate_feedback(api_prompt_text, temperature=0, interaction_txt=self.chat_log)

    def template_reference_prompt(self):
        """Select which stored task codes to show the model as references.

        Returns:
            When the reference-selection template exists: a string with one
            fenced code block per selected key found in
            ``memory.online_code_buffer``. Otherwise: whatever
            ``sample_list_reference`` returns.

        Raises:
            NameError: If the model reply does not define ``code_reference``.
        """
        template_name = "cliport_prompt_code_reference_selection_template.txt"
        if os.path.exists(f"{self.prompt_folder}/{template_name}"):
            # NOTE(review): unlike propose_task, the return value of
            # add_to_txt is rebound to self.chat_log here (kept as in the
            # original) - confirm add_to_txt returns the log.
            self.chat_log = add_to_txt(self.chat_log, "================= Code Reference!", with_print=True)
            code_reference_question = self._read_prompt(template_name)
            code_reference_question = code_reference_question.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
            code_reference_question = code_reference_question.replace(
                "TASK_CODE_LIST_TEMPLATE", str(list(self.memory.online_code_buffer.keys())))
            code_reference_question = code_reference_question.replace("TASK_STRING_TEMPLATE", str(self.new_task))

            res = generate_feedback(code_reference_question, temperature=0., interaction_txt=self.chat_log)
            code_reference_cmd = extract_list(res, prefix='code_reference')
            code_reference = self._run_generated_code(code_reference_cmd, "code_reference")

            reference_blocks = []
            for key in code_reference:
                if key in self.memory.online_code_buffer:
                    reference_blocks.append(f'```\n{self.memory.online_code_buffer[key]}\n```\n\n')
                else:
                    print("missing task reference code:", key)
            task_code_reference_replace_prompt = ''.join(reference_blocks)
        else:
            # NOTE(review): ``base_task_codes`` is not defined anywhere in the
            # visible file, so this branch still raises NameError; its
            # intended source must be confirmed. ``cfg`` was also undefined
            # and is now read from ``self.cfg``.
            task_code_reference_replace_prompt = sample_list_reference(
                base_task_codes, sample_num=self.cfg['task_code_candidate_num'])

        # print("Template Reference Code PROMPT: ", task_code_reference_replace_prompt)
        return task_code_reference_replace_prompt

    def implement_task(self):
        """Generate Code for the task.

        Returns:
            ``(code, task_name)`` as produced by ``extract_code``, or ``None``
            when the extracted task name is empty.
        """
        split_template = "cliport_prompt_code_split_template.txt"
        selection_template = "cliport_prompt_code_reference_selection_template.txt"

        code_prompt_text = self._read_prompt(split_template)
        code_prompt_text = code_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])

        if self.use_template or os.path.exists(f"{self.prompt_folder}/{selection_template}"):
            task_code_reference_replace_prompt = self.template_reference_prompt()
            code_prompt_text = code_prompt_text.replace(
                "TASK_CODE_REFERENCE_TEMPLATE", task_code_reference_replace_prompt)
        elif os.path.exists(f"{self.prompt_folder}/{split_template}"):
            self.chat_log = add_to_txt(self.chat_log, "================= Code Generation!", with_print=True)
            code_prompt_text = code_prompt_text.replace("TASK_STRING_TEMPLATE", str(self.new_task))

        res = generate_feedback(
            code_prompt_text, temperature=0, interaction_txt=self.chat_log)
        code, task_name = extract_code(res)

        print("Save code to:", self.model_output_dir, task_name + "_code_output")
        save_text(self.model_output_dir, task_name + "_code_output", code)

        if not task_name:
            print("empty task name:", task_name)
            return None

        return code, task_name