# GenSim / gensim/memory.py
# LeroyWaa's picture
# add gensim code
# 8fc2b4e
# raw
# history blame contribute delete
# No virus
# 5.17 kB
import numpy as np
import os
import IPython
import random
import json
from gensim.utils import save_text
class Memory:
    """
    Class that maintains a buffer of generated tasks and codes.

    "Offline" memory is the set of JSON index files stored under
    ``cfg["prompt_data_path"]`` plus the task source files under
    ``cliport/tasks/`` and ``cliport/generated_tasks/``. "Online" memory is
    the in-process copy held in the ``online_*_buffer`` dictionaries.

    Config keys read by this class: ``prompt_folder``, ``prompt_data_path``,
    ``load_memory`` and ``model_output_dir``.
    """

    # Folders searched (in this order) for the source file of a task.
    _TASK_CODE_FOLDERS = ("cliport/tasks/", "cliport/generated_tasks/")

    def __init__(self, cfg):
        """Load the offline memory described by ``cfg`` into the online buffers.

        Args:
            cfg: mapping-like configuration; see the class docstring for the
                keys that are read.
        """
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.data_path = cfg["prompt_data_path"]
        self.cfg = cfg

        # a chat history is a list of strings
        self.chat_log = []
        # task name -> task description
        self.online_task_buffer = {}
        # code file name -> source code text
        self.online_code_buffer = {}
        # asset entries exactly as stored in base_assets.json
        self.online_asset_buffer = {}

        # directly load current offline memory into online memory
        base_tasks, base_assets, base_task_codes = self.load_offline_memory()
        self.online_task_buffer.update(base_tasks)
        self.online_asset_buffer.update(base_assets)

        # load each code file; the original cliport task path takes precedence
        # over the generated one, and files found in neither are skipped
        for task_file in base_task_codes:
            for folder in self._TASK_CODE_FOLDERS:
                code_path = folder + task_file
                if os.path.exists(code_path):
                    self.online_code_buffer[task_file] = self._read_text(code_path)
                    break

        print(f"load {len(self.online_code_buffer)} tasks for memory from offline to online:")
        cache_embedding_path = "outputs/task_cache_embedding.npz"

        # NOTE: task_code_embedding is only defined when the cache file exists.
        if os.path.exists(cache_embedding_path):
            print("task code embeding:", cache_embedding_path)
            self.task_code_embedding = np.load(cache_embedding_path)

    @staticmethod
    def _read_text(path):
        """Return the full text of ``path``, closing the file afterwards."""
        with open(path) as fhandle:
            return fhandle.read()

    @staticmethod
    def _read_json(path):
        """Return the parsed JSON content of ``path``, closing the file afterwards."""
        with open(path) as fhandle:
            return json.load(fhandle)

    def save_run(self, new_task):
        """Save chat history and potentially save base memory.

        The entries of ``chat_log`` are concatenated in order and handed to
        ``save_text`` together with ``cfg['model_output_dir']`` and the name
        ``<task-name>_full_output``.

        Args:
            new_task: task description; only ``new_task["task-name"]`` is read.
        """
        output_name = f'{new_task["task-name"]}_full_output'
        print("save all interaction to :", output_name)
        unroll_chatlog = "".join(self.chat_log)
        save_text(self.cfg['model_output_dir'], output_name, unroll_chatlog)

    def save_task_to_online(self, new_task, code):
        """(Not dumping the task offline.) Save the task information for online bootstrapping.

        Args:
            new_task: task description, stored under ``new_task["task-name"]``.
            code: source code text, stored under ``<task_name_with_underscores>.py``.
        """
        self.online_task_buffer[new_task['task-name']] = new_task
        code_file_name = new_task["task-name"].replace("-", "_") + ".py"

        # code file name: actual code in contrast to offline code files format.
        self.online_code_buffer[code_file_name] = code

    def save_task_to_offline(self, new_task, code):
        """Save the current task descriptions, assets, and code, if it passes
        reflection and environment test.

        The code file is written (and registered in
        ``generated_task_codes.json``) only when its file name is not yet
        registered; the task description in ``generated_tasks.json`` is
        written in every case, replacing any entry with the same task name.

        Args:
            new_task: task description; must contain ``"task-name"``.
            code: source code text for the task.
        """
        generated_task_code_path = os.path.join(
            self.cfg["prompt_data_path"], "generated_task_codes.json"
        )
        generated_task_codes = self._read_json(generated_task_code_path)
        new_file_path = new_task["task-name"].replace("-", "_") + ".py"

        if new_file_path not in generated_task_codes:
            generated_task_codes.append(new_file_path)

            python_file_path = "cliport/generated_tasks/" + new_file_path
            print(f"save {new_task['task-name']} to ", python_file_path)

            with open(python_file_path, "w") as fhandle:
                fhandle.write(code)

            with open(generated_task_code_path, "w") as outfile:
                json.dump(generated_task_codes, outfile, indent=4)
        else:
            # new_file_path already carries the ".py" suffix
            print(f"{new_file_path} already exists.")

        # save task descriptions
        generated_task_path = os.path.join(
            self.cfg["prompt_data_path"], "generated_tasks.json"
        )
        generated_tasks = self._read_json(generated_task_path)
        generated_tasks[new_task["task-name"]] = new_task

        with open(generated_task_path, "w") as outfile:
            json.dump(generated_tasks, outfile, indent=4)

    def load_offline_memory(self):
        """Get the current task descriptions, assets, and code.

        Reads ``base_tasks.json``, ``base_assets.json`` and
        ``base_task_codes.json`` from ``self.data_path``. When
        ``cfg["load_memory"]`` is truthy, the entries of
        ``generated_tasks.json`` are merged into the tasks and the entries of
        ``generated_task_codes.json`` not already listed are appended to the
        code list.

        Returns:
            Tuple ``(base_tasks, base_assets, base_task_codes)`` holding the
            parsed content of the three base files after the optional merge.
        """
        base_tasks = self._read_json(os.path.join(self.data_path, "base_tasks.json"))
        base_assets = self._read_json(os.path.join(self.data_path, "base_assets.json"))
        base_task_codes = self._read_json(
            os.path.join(self.data_path, "base_task_codes.json")
        )

        if self.cfg["load_memory"]:
            generated_task_path = os.path.join(self.data_path, "generated_tasks.json")
            generated_task_code_path = os.path.join(
                self.data_path, "generated_task_codes.json"
            )

            print("original base task num:", len(base_tasks))
            base_tasks.update(self._read_json(generated_task_path))
            # generated_assets.json is not merged into base_assets here.

            for task in self._read_json(generated_task_code_path):
                if task not in base_task_codes:
                    base_task_codes.append(task)

            print("current base task num:", len(base_tasks))

        return base_tasks, base_assets, base_task_codes