multimodalart HF staff commited on
Commit
ee2e0b7
1 Parent(s): 6845a5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -7,25 +7,24 @@ from mario_gpt.prompter import Prompter
7
  from mario_gpt.lm import MarioLM
8
  from mario_gpt.utils import view_level, convert_level_to_png
9
 
10
- import os
11
- import subprocess
12
 
13
- from pyngrok import ngrok
 
14
 
15
  mario_lm = MarioLM()
16
  device = torch.device('cuda')
17
  mario_lm = mario_lm.to(device)
18
  TILE_DIR = "data/tiles"
19
 
20
- subprocess.Popen(["python3","-m","http.server","7861"])
21
- ngrok.set_auth_token(os.environ.get('NGROK_TOKEN'))
22
- http_tunnel = ngrok.connect(7861,bind_tls=True)
23
 
24
  def make_html_file(generated_level):
25
  level_text = f"""{'''
26
  '''.join(view_level(generated_level,mario_lm.tokenizer))}"""
27
  unique_id = uuid.uuid1()
28
- with open(f"demo-{unique_id}.html", 'w', encoding='utf-8') as f:
29
  f.write(f'''<!DOCTYPE html>
30
  <html lang="en">
31
 
@@ -42,7 +41,7 @@ def make_html_file(generated_level):
42
  cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
43
  }});
44
  cheerpjCreateDisplay(512, 500);
45
- cheerpjRunJar("/app/mario.jar");
46
  </script>
47
  </html>''')
48
  return f"demo-{unique_id}.html"
@@ -61,9 +60,9 @@ def generate(pipes, enemies, blocks, elevation, temperature = 2.0, level_size =
61
  filename = make_html_file(generated_level)
62
  img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
63
 
64
- gradio_html = f'''<div style="border: 2px solid;">
65
- <iframe width=512 height=512 style="margin: 0 auto" src="{http_tunnel.public_url}/{filename}"></iframe>
66
- <p style="text-align:center">Press the arrow keys to move. Press <code>s</code> to jump and <code>a</code> to shoot flowers</p>
67
  </div>'''
68
  return [img, gradio_html]
69
 
@@ -72,16 +71,16 @@ with gr.Blocks() as demo:
72
  [[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
73
  ''')
74
  with gr.Tabs():
75
- with gr.TabItem("Type prompt"):
76
- text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
77
  with gr.TabItem("Compose prompt"):
78
  with gr.Row():
79
- pipes = gr.Radio(["no", "little", "some", "many"], label="pipes")
80
- enemies = gr.Radio(["no", "little", "some", "many"], label="enemies")
81
  with gr.Row():
82
- blocks = gr.Radio(["little", "some", "many"], label="blocks")
83
- elevation = gr.Radio(["low", "high"], label="elevation")
84
-
 
 
85
  with gr.Accordion(label="Advanced settings", open=False):
86
  temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
87
  level_size = gr.Number(value=1399, precision=0, label="level_size")
@@ -104,4 +103,7 @@ with gr.Blocks() as demo:
104
  fn=generate,
105
  cache_examples=True,
106
  )
107
- demo.launch()
 
 
 
 
7
  from mario_gpt.lm import MarioLM
8
  from mario_gpt.utils import view_level, convert_level_to_png
9
 
10
+ from fastapi import FastAPI
11
+ from fastapi.staticfiles import StaticFiles
12
 
13
+ import os
14
+ import uvicorn
15
 
16
  mario_lm = MarioLM()
17
  device = torch.device('cuda')
18
  mario_lm = mario_lm.to(device)
19
  TILE_DIR = "data/tiles"
20
 
21
+ app = FastAPI()
 
 
22
 
23
  def make_html_file(generated_level):
24
  level_text = f"""{'''
25
  '''.join(view_level(generated_level,mario_lm.tokenizer))}"""
26
  unique_id = uuid.uuid1()
27
+ with open(f"static/demo-{unique_id}.html", 'w', encoding='utf-8') as f:
28
  f.write(f'''<!DOCTYPE html>
29
  <html lang="en">
30
 
 
41
  cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
42
  }});
43
  cheerpjCreateDisplay(512, 500);
44
+ cheerpjRunJar("/app/static/mario.jar");
45
  </script>
46
  </html>''')
47
  return f"demo-{unique_id}.html"
 
60
  filename = make_html_file(generated_level)
61
  img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
62
 
63
+ gradio_html = f'''<div>
64
+ <iframe width=512 height=512 style="margin: 0 auto" src="static/{filename}"></iframe>
65
+ <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
66
  </div>'''
67
  return [img, gradio_html]
68
 
 
71
  [[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
72
  ''')
73
  with gr.Tabs():
 
 
74
  with gr.TabItem("Compose prompt"):
75
  with gr.Row():
76
+ pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?")
77
+ enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?")
78
  with gr.Row():
79
+ blocks = gr.Radio(["little", "some", "many"], label="How many blocks?")
80
+ elevation = gr.Radio(["low", "high"], label="Elevation?")
81
+ with gr.TabItem("Type prompt"):
82
+ text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
83
+
84
  with gr.Accordion(label="Advanced settings", open=False):
85
  temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
86
  level_size = gr.Number(value=1399, precision=0, label="level_size")
 
103
  fn=generate,
104
  cache_examples=True,
105
  )
106
+
107
+ app.mount("/static", StaticFiles(directory="static", html=True), name="static")
108
+ app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/")
109
+ uvicorn.run(app, host="0.0.0.0", port=7860)