from __future__ import annotations

import random
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from scipy import stats
from transformers import pipeline

from mario_gpt.dataset import MarioDataset
from mario_gpt.utils import view_level

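# Default bin edges for mapping raw enemy/pipe/block counts to keywords.
# The dict has the same shape as the one returned by Prompter.dataset_statistics,
# so custom thresholds can be computed from a dataset and passed in as `statistics`.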
STATISTICS = {
    "enemy": np.array([1.0, 3.0, 7.0]),
    "pipe": np.array([0.0, 2.0, 5.0]),
    "block": np.array([50.0, 75.0, 176.0]),
}

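# Hugging Face checkpoint used by the feature-extraction pipeline that embeds prompts.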
FEATURE_EXTRACTION_MODEL = "facebook/bart-base"


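# Prompter converts a tokenized level into a short text prompt such as
# "some pipes, many enemies, some blocks, low elevation" and embeds that prompt
# with a frozen feature-extraction model.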
class Prompter:
    def __init__(
        self,
        level_tokenizer,
        prompter_model: str = FEATURE_EXTRACTION_MODEL,
        use_raw_counts: bool = False,
        statistics: Optional[Dict[str, Any]] = None,
    ):
        self.prompter_model = prompter_model
        self.feature_extraction = pipeline(
            "feature-extraction",
            model=prompter_model,
            tokenizer=prompter_model,
            framework="pt",
        )

        self.level_tokenizer = level_tokenizer
        self.use_raw_counts = use_raw_counts
        self.statistics = statistics
        if statistics is None:
            self.statistics = STATISTICS

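    # Each *_thresholds property pairs the bin edges from self.statistics with
    # the keyword for each bin; np.digitize picks the bin for a raw count.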
    @property
    def pipe_thresholds(self) -> Tuple[List[int], List[str]]:
        thresholds = self.statistics["pipe"]
        keywords = ["no", "little", "some", "many"]
        return thresholds, keywords

    @property
    def enemy_thresholds(self) -> Tuple[List[int], List[str]]:
        thresholds = self.statistics["enemy"]
        keywords = ["no", "little", "some", "many"]
        return thresholds, keywords

    @property
    def block_thresholds(self) -> Tuple[List[int], List[str]]:
        thresholds = self.statistics["block"]
        # "little" appears twice so the two lowest bins share a keyword:
        # levels always have blocks.
        keywords = ["little", "little", "some", "many"]
        return thresholds, keywords

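    # Counting helpers operate on the flattened string form of a level:
    # a "<>" pair marks a pipe opening, "E" and "B" tiles count as enemies,
    # and "X", "S", "?", "Q" tiles count as blocks.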
    def count_pipes(self, flattened_level: str) -> int:
        return flattened_level.count("<>")

    def count_enemies(self, flattened_level: str) -> int:
        return flattened_level.count("E") + flattened_level.count("B")

    def count_blocks(self, flattened_level: str) -> int:
        return np.sum([flattened_level.count(char) for char in ["X", "S", "?", "Q"]])

    def _flatten_level(self, string_level: List[str]) -> str:
        return "".join(string_level)

    def pipe_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        count = self.count_pipes(flattened_level)
        keyword = f"{count}"
        if not self.use_raw_counts:
            thresholds, keywords = self.pipe_thresholds
            threshold = np.digitize(count, thresholds, right=True)
            keyword = keywords[threshold]
        return f"{keyword} pipes", keyword

    def enemy_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        count = self.count_enemies(flattened_level)
        keyword = f"{count}"
        if not self.use_raw_counts:
            thresholds, keywords = self.enemy_thresholds
            threshold = np.digitize(count, thresholds, right=True)
            keyword = keywords[threshold]
        return f"{keyword} enemies", keyword

    def block_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        count = self.count_blocks(flattened_level)
        keyword = f"{count}"
        if not self.use_raw_counts:
            thresholds, keywords = self.block_thresholds
            threshold = np.digitize(count, thresholds, right=True)
            keyword = keywords[threshold]
        return f"{keyword} blocks", keyword

    def elevation_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        top_rows = level[:6]  # rows at elevation 8 and above
        for row in top_rows:
            if "X" in row or "<" in row or ">" in row:
                return "high elevation", "high"
        return "low elevation", "low"

    def output_hidden(self, prompt: str, device: torch.device = torch.device("cpu")):
        # Average the token embeddings along the sequence dimension to get a
        # single 768-dimensional prompt vector of shape (1, 768).
        return (
            self.feature_extraction(prompt, return_tensors="pt")[0]
            .mean(0)
            .to(device)
            .view(1, -1)
        )

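    # Recompute the keyword thresholds from a dataset by taking the 33rd, 66th
    # and 95th percentile of each feature count; the result can be passed to
    # Prompter(statistics=...).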
    def dataset_statistics(self, dataset: MarioDataset):
        enemy_counts = []
        pipe_counts = []
        block_counts = []
        for i in range(len(dataset)):
            level, _ = dataset[i]
            str_level = self._flatten_level(view_level(level, dataset.tokenizer))

            enemy_counts.append(self.count_enemies(str_level))
            pipe_counts.append(self.count_pipes(str_level))
            block_counts.append(self.count_blocks(str_level))

        d = {"enemy": {}, "pipe": {}, "block": {}}
        d["enemy"] = stats.mstats.mquantiles(enemy_counts, [0.33, 0.66, 0.95])
        d["pipe"] = stats.mstats.mquantiles(pipe_counts, [0.33, 0.66, 0.95])
        d["block"] = stats.mstats.mquantiles(block_counts, [0.33, 0.66, 0.95])
        return d

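    # Calling the prompter either describes the supplied level or, when
    # sample_prompt=True, draws a random prompt; either way the prompt text is
    # embedded with output_hidden.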
    def __call__(
        self, level: Optional[torch.Tensor] = None, sample_prompt: bool = False
    ) -> Tuple[str, torch.Tensor, Dict[str, str], Optional[List[str]]]:
        device: torch.device = torch.device("cpu")
        if not sample_prompt:
            if level is None:
                raise ValueError("Level must be provided if sample_prompt is not true!")
            str_level = view_level(level, self.level_tokenizer)
            flattened_level = self._flatten_level(str_level)

            pipe_prompt, _ = self.pipe_prompt(flattened_level, str_level)
            enemy_prompt, _ = self.enemy_prompt(flattened_level, str_level)
            block_prompt, _ = self.block_prompt(flattened_level, str_level)
            elevation_prompt, _ = self.elevation_prompt(flattened_level, str_level)
            device = level.device
        else:
            str_level = None
            pipe_prompt = random.choice(["no", "little", "some", "many"]) + " pipes"
            enemy_prompt = random.choice(["no", "little", "some", "many"]) + " enemies"
            block_prompt = (
                random.choice(["little", "little", "some", "many"]) + " blocks"
            )  # levels always have blocks
            elevation_prompt = random.choice(["low", "high"]) + " elevation"

        prompt_dict = {
            "pipe": pipe_prompt,
            "enemy": enemy_prompt,
            "block": block_prompt,
            "elevation_prompt": elevation_prompt,
        }
        prompt = f"{pipe_prompt}, {enemy_prompt}, {block_prompt}, {elevation_prompt}"
        hidden = self.output_hidden(prompt, device=device)
        return prompt, hidden, prompt_dict, str_level
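
# Minimal usage sketch (assumes a MarioDataset and its tokenizer are available;
# the exact constructor arguments are project-specific):
#
#   dataset = MarioDataset(...)
#   prompter = Prompter(level_tokenizer=dataset.tokenizer)
#   level, _ = dataset[0]
#   prompt, hidden, prompt_dict, str_level = prompter(level=level)
#   # or sample a random prompt for unconditional generation:
#   prompt, hidden, prompt_dict, _ = prompter(sample_prompt=True)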