Spaces: Running on Zero
import gradio as gr
import torch
import uuid
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import os
import uvicorn

# Load the pretrained MarioGPT language model and move it to the GPU
# (the Space assumes a CUDA device is available).
mario_lm = MarioLM()
device = torch.device('cuda')
mario_lm = mario_lm.to(device)
TILE_DIR = "data/tiles"

app = FastAPI()
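# The generated HTML files below are written into ./static, which is also the
# directory served by StaticFiles at the bottom of this file. Creating it up front
# is a small safety check (an assumption: the directory may already ship with the Space).
os.makedirs("static", exist_ok=True)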
def make_html_file(generated_level):
    # Join the level's rows into one newline-separated string that the
    # CheerpJ-hosted Mario jar reads as a level file.
    level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
    unique_id = uuid.uuid1()
    with open(f"static/demo-{unique_id}.html", 'w', encoding='utf-8') as f:
        f.write(f'''<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="utf-8">
    <title>Mario Game</title>
    <script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>
<body>
</body>
<script>
cheerpjInit().then(function () {{
    cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
}});
cheerpjCreateDisplay(512, 500);
cheerpjRunJar("/app/static/mario.jar");
</script>
</html>''')
    return f"demo-{unique_id}.html"
def trim_level(level):
    # Levels are 14 tiles tall, so a valid flattened level has a token length
    # that is a multiple of 14; drop any trailing, incomplete column.
    mod = level.shape[-1] % 14
    if mod > 0:
        return level[:, :-mod]
    return level
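# Quick sanity check of the trimming above (a sketch, not executed by the app):
# a flattened sample of 1405 tokens carries one extra, incomplete column, so
#   trim_level(torch.zeros(1, 1405)).shape  ->  torch.Size([1, 1400])  # 100 full columns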
def reset_state(seed_state):
    length = len(seed_state)
    print(f"Resetting state with {length} levels!")
    for _ in range(length):
        seed_state.pop()
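# Note: reset_state empties the seed_state list in place rather than returning a new
# value, which is why the reset button's click handler below declares no outputs.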
def _generate_level(prompts, seed, level_size, temperature):
    print(f"Using prompts: {prompts}")
    generated_levels = mario_lm.sample(
        prompts=prompts,
        num_steps=level_size,
        temperature=temperature,
        use_tqdm=True,
        seed=seed,
    )
    generated_levels = trim_level(generated_levels)
    return generated_levels
def _make_gradio_html(level):
    filename = make_html_file(level)
    gradio_html = f'''<div>
    <iframe width="512" height="512" style="margin: 0 auto" src="static/{filename}"></iframe>
    <p style="text-align:center">Use the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump, and <code>d</code> to shoot fireballs.</p>
</div>'''
    return gradio_html
def initialize_generate(pipes, enemies, blocks, elevation, temperature=2.4, level_size=1400):
    prompts = [f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"]
    generated_levels = _generate_level(prompts, None, level_size, temperature)
    level = generated_levels.squeeze().detach().cpu()
    img = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
    return [img, _make_gradio_html(level)]
def generate_choices(pipes, enemies, blocks, elevation, temperature=2.4, level_size=1400, prompt="", seed_state=[]):
    NUM_SAMPLES = 2
    if prompt == "":
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    prompts = [prompt] * NUM_SAMPLES
    seed = None
    if len(seed_state) > 0:
        # condition generation on the last 48 columns (48 * 14 tokens) of the level so far
        seed = torch.cat(seed_state).squeeze()[-48 * 14:].view(1, -1).repeat(NUM_SAMPLES, 1)
    generated_levels = _generate_level(prompts, seed, level_size, temperature).detach().cpu().squeeze()
    level_choices = [generated_level[-level_size:] for generated_level in generated_levels]
    level_choice_images = [convert_level_to_png(generated_level[-level_size:], TILE_DIR, mario_lm.tokenizer)[0] for generated_level in generated_levels]
    # level choices + separate images
    return [level_choices, *level_choice_images]
def update_level_state(choice_id, level_choices, seed_state):
    num_choice = int(choice_id)
    level_choice = level_choices[num_choice]
    # append the chosen segment to the seed state
    seed_state.append(level_choice)
    # the full level is the concatenation of every accepted segment
    level = torch.cat(seed_state).squeeze()
    # render the final image and the playable html
    img = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
    gradio_html = _make_gradio_html(level)
    # outputs: image, html, seed state, cleared choices, cleared choice images, current level size
    return img, gradio_html, seed_state, None, None, None, level.shape[-1]
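# Flow of the demo: each generation round produces two candidate continuations; the
# user picks one, update_level_state appends it to seed_state, and the next call to
# generate_choices conditions on that growing level.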
with gr.Blocks().queue() as demo:
    gr.Markdown(
        '''### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
'''
    )
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], label="How many blocks?")
                elevation = gr.Radio(["low", "high"], label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Number(value=2.0, label="temperature: increase this for more diverse, but lower quality, generations")
        level_size = gr.Number(value=1400, precision=0, label="level_size")
    generate_btn = gr.Button("Generate Level")
    reset_btn = gr.Button("Reset Level")
    with gr.Row():
        with gr.Box():
            level_play = gr.HTML()
            level_image = gr.Image(label="Current Level")
        with gr.Box():
            with gr.Column():
                level_choice1_image = gr.Image(label="Sample Choice 1")
                level_choice1_btn = gr.Button("Sample Choice 1")
            with gr.Column():
                level_choice2_image = gr.Image(label="Sample Choice 2")
                level_choice2_btn = gr.Button("Sample Choice 2")
    current_level_size = gr.Number(0, visible=True, label="Current Level Size")
    seed_state = gr.State([])
    state_choices = gr.State(None)
    image_choice_1_id = gr.Number(0, visible=False)
    image_choice_2_id = gr.Number(1, visible=False)
    # choice buttons
    level_choice1_btn.click(fn=update_level_state, inputs=[image_choice_1_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])
    level_choice2_btn.click(fn=update_level_state, inputs=[image_choice_2_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])
    # generate_btn
    generate_btn.click(fn=generate_choices, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt, seed_state], outputs=[state_choices, level_choice1_image, level_choice2_image])
    # reset btn
    reset_btn.click(fn=reset_state, inputs=[seed_state], outputs=[])
    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.4],
            ["no", "no", "many", "high", 2.8],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size],
        outputs=[level_image, level_play],
        fn=initialize_generate,
        cache_examples=True,
    )
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/")

uvicorn.run(app, host="0.0.0.0", port=7860)
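# To try this locally (assuming the mario-gpt package, the data/tiles sprites, and the
# static/mario.jar played by CheerpJ are available):
#   python app.py   # then open http://localhost:7860/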