import gradio as gr
import torch

from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

# Directory containing the tile images used to render generated levels.
TILE_DIR = "data/tiles"

# Default sampling settings forwarded to ``MarioLM.sample``.
NUM_STEPS = 1399
TEMPERATURE = 2.0

mario_lm = MarioLM()
# Use the GPU when one is available; hard-coding "cuda" made the script
# crash at import time on CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)


def update(prompt, num_steps=NUM_STEPS, temperature=TEMPERATURE):
    """Generate a Mario level from a text prompt and render it as an image.

    Args:
        prompt: Text description of the desired level, e.g.
            "many pipes, some enemies".
        num_steps: Number of sampling steps forwarded to ``mario_lm.sample``.
        temperature: Sampling temperature forwarded to ``mario_lm.sample``.

    Returns:
        Element ``[0]`` of the result of ``convert_level_to_png`` for the
        generated level (displayed by the ``gr.Image`` output component).
    """
    generated_level = mario_lm.sample(
        prompts=[prompt],
        num_steps=num_steps,
        temperature=temperature,
        use_tqdm=True,
    )
    # NOTE(review): squeeze() presumably drops the singleton batch dimension
    # left by sampling a single prompt — confirm against MarioLM.sample.
    return convert_level_to_png(
        generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer
    )[0]


with gr.Blocks() as demo:
    gr.Markdown("## Demo for ['MarioGPT: Open-Ended Text2Level Generation through Large Language Models'](https://github.com/shyamsn97/mario-gpt)")

    prompt = gr.Textbox(label="Enter your MarioGPT prompt")
    level_image = gr.Image()

    btn = gr.Button("Generate level")
    btn.click(fn=update, inputs=prompt, outputs=level_image)

    # cache_examples=True runs ``update`` on every example while the app is
    # built, so clicking an example shows a precomputed level.
    gr.Examples(
        examples=[
            "many pipes, many enemies, some blocks, high elevation",
            "little pipes, little enemies, many blocks, high elevation",
            "many pipes, some enemies",
            "no pipes, no enemies, many blocks",
        ],
        inputs=prompt,
        outputs=level_image,
        fn=update,
        cache_examples=True,
    )

demo.launch()